arahrooh committed
Commit ecb9bc5 · 1 Parent(s): 45ed548

Use Hugging Face Inference API on Spaces instead of loading models locally
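At its core, the change swaps local model.generate(...) calls for the huggingface_hub chat-completion client. A minimal sketch of the pattern the commit adopts (the model id and token are placeholders; InferenceClient accepts api_key as an alias for token in recent huggingface_hub releases):

    from huggingface_hub import InferenceClient

    client = InferenceClient(api_key="hf_...")  # normally read from the HF_TOKEN secret
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        messages=[{"role": "user", "content": "What is anemia?"}],
        max_tokens=512,
        temperature=0.2,
    )
    print(completion.choices[0].message.content)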

Files changed (1): app.py (+193, -23)
app.py CHANGED
@@ -23,13 +23,21 @@ import gradio as gr
  import argparse
  import sys
  import os
- from typing import Tuple, Optional
+ from typing import Tuple, Optional, List
  import logging
  import textstat
  import torch

  # Import from bot.py
- from bot import RAGBot, parse_args
+ from bot import RAGBot, parse_args, Chunk
+
+ # For Hugging Face Inference API
+ try:
+     from huggingface_hub import InferenceClient
+     HF_INFERENCE_AVAILABLE = True
+ except ImportError:
+     HF_INFERENCE_AVAILABLE = False
+     logging.warning("huggingface_hub not available, InferenceClient will not work")

  # Set up logging
  logging.basicConfig(level=logging.INFO)
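The try/except guard above is the standard optional-dependency pattern: the HF_INFERENCE_AVAILABLE flag is checked later (in GradioRAGInterface.__init__) before any InferenceClient is constructed, so a missing huggingface_hub degrades to the local-model path instead of crashing at import time. The same pattern with a generic dependency (names hypothetical):

    try:
        import fancy_accelerator  # hypothetical optional dependency
        HAS_ACCEL = True
    except ImportError:
        HAS_ACCEL = False

    def run(x):
        if HAS_ACCEL:
            return fancy_accelerator.fast(x)
        return x  # fall back to the plain value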
@@ -159,12 +167,151 @@ EXAMPLE_QUESTIONS = [
  ]


+ class InferenceAPIBot:
+     """Wrapper that uses Hugging Face Inference API instead of loading models locally"""
+
+     def __init__(self, bot: RAGBot, hf_token: str):
+         """Initialize with a RAGBot (for vector DB) and HF token for Inference API"""
+         self.bot = bot  # Use bot for vector DB and formatting
+         self.client = InferenceClient(api_key=hf_token)
+         self.current_model = bot.args.model
+         logger.info(f"InferenceAPIBot initialized with model: {self.current_model}")
+
+     def generate_answer(self, prompt: str, **kwargs) -> str:
+         """Generate answer using Inference API"""
+         try:
+             # Convert prompt to chat format
+             messages = [{"role": "user", "content": prompt}]
+
+             # Call Inference API
+             completion = self.client.chat.completions.create(
+                 model=self.current_model,
+                 messages=messages,
+                 max_tokens=kwargs.get('max_new_tokens', 512),
+                 temperature=kwargs.get('temperature', 0.2),
+                 top_p=kwargs.get('top_p', 0.9),
+             )
+
+             answer = completion.choices[0].message.content
+             return answer
+         except Exception as e:
+             logger.error(f"Error calling Inference API: {e}", exc_info=True)
+             return f"Error generating answer: {str(e)}"
+
+     def enhance_readability(self, answer: str, target_level: str = "middle_school") -> Tuple[str, float]:
+         """Enhance readability using Inference API"""
+         try:
+             # Define prompts for different reading levels (same as bot.py)
+             if target_level == "middle_school":
+                 level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
+                 instructions = """
+ - Use simpler medical terms or explain them
+ - Medium-length sentences
+ - Clear, structured explanations
+ - Keep important medical information accessible"""
+             elif target_level == "high_school":
+                 level_description = "high school reading level (ages 15-18, 9th-12th grade)"
+                 instructions = """
+ - Use appropriate medical terminology with context
+ - Varied sentence length
+ - Comprehensive yet accessible explanations
+ - Maintain technical accuracy while ensuring clarity"""
+             elif target_level == "college":
+                 level_description = "college reading level (undergraduate level, ages 18-22)"
+                 instructions = """
+ - Use standard medical terminology with brief explanations
+ - Professional and clear writing style
+ - Include relevant clinical context
+ - Maintain scientific accuracy and precision
+ - Appropriate for undergraduate students in health sciences"""
+             elif target_level == "doctoral":
+                 level_description = "doctoral/professional reading level (graduate level, medical professionals)"
+                 instructions = """
+ - Use advanced medical and scientific terminology
+ - Include detailed clinical and research context
+ - Reference specific mechanisms, pathways, and evidence
+ - Provide comprehensive technical explanations
+ - Appropriate for medical professionals, researchers, and graduate students
+ - Include nuanced discussions of clinical implications and research findings"""
+             else:
+                 raise ValueError(f"Unknown target_level: {target_level}")
+
+             # Create messages for chat API
+             system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
+ {instructions}
+ - Keep the same important information but adapt the complexity
+ - Provide context for technical terms
+ - Ensure the answer is informative yet understandable"""
+
+             user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"
+
+             messages = [
+                 {"role": "system", "content": system_message},
+                 {"role": "user", "content": user_message}
+             ]
+
+             # Call Inference API
+             completion = self.client.chat.completions.create(
+                 model=self.current_model,
+                 messages=messages,
+                 max_tokens=512 if target_level in ["college", "doctoral"] else 384,
+                 temperature=0.4 if target_level in ["college", "doctoral"] else 0.3,
+             )
+
+             enhanced_answer = completion.choices[0].message.content
+             # Clean the answer (same as bot.py)
+             cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)
+
+             # Calculate Flesch-Kincaid grade level
+             try:
+                 flesch_score = textstat.flesch_kincaid_grade(cleaned)
+             except Exception:
+                 flesch_score = 0.0
+
+             return cleaned, flesch_score
+         except Exception as e:
+             logger.error(f"Error enhancing readability: {e}", exc_info=True)
+             return answer, 0.0
+
+     # Delegate other methods to bot
+     def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
+         return self.bot.format_prompt(context_chunks, question)
+
+     def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
+         return self.bot.retrieve_with_scores(query, k)
+
+     def _categorize_question(self, question: str) -> str:
+         return self.bot._categorize_question(question)
+
+     @property
+     def args(self):
+         return self.bot.args
+
+     @property
+     def vector_retriever(self):
+         return self.bot.vector_retriever
+
+
  class GradioRAGInterface:
      """Wrapper class to integrate RAGBot with Gradio"""

-     def __init__(self, initial_bot: RAGBot):
-         self.bot = initial_bot
-         self.current_model = initial_bot.args.model
+     def __init__(self, initial_bot: RAGBot, use_inference_api: bool = False):
+         # Check if we should use Inference API (on Spaces)
+         if use_inference_api and HF_INFERENCE_AVAILABLE:
+             hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
+             if hf_token:
+                 self.bot = InferenceAPIBot(initial_bot, hf_token)
+                 self.use_inference_api = True
+                 logger.info("Using Hugging Face Inference API")
+             else:
+                 logger.warning("HF_TOKEN not found, falling back to local model")
+                 self.bot = initial_bot
+                 self.use_inference_api = False
+         else:
+             self.bot = initial_bot
+             self.use_inference_api = False
+
+         self.current_model = getattr(self.bot, 'current_model', self.bot.args.model)
          self.data_dir = initial_bot.args.data_dir
          logger.info("GradioRAGInterface initialized")
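Taken together, the wrapper keeps retrieval on the Space and moves only generation to the API. A rough sketch of the call flow it enables, using the methods defined and delegated above (question text and parameters are illustrative):

    import os

    api_bot = InferenceAPIBot(bot, os.getenv("HF_TOKEN"))
    question = "What causes anemia?"
    chunks, scores = api_bot.retrieve_with_scores(question, k=5)  # local vector DB
    prompt = api_bot.format_prompt(chunks, question)              # same prompt as bot.py
    answer = api_bot.generate_answer(prompt, max_new_tokens=512)  # remote generation
    simpler, grade = api_bot.enhance_readability(answer, target_level="middle_school")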
@@ -194,22 +341,29 @@ class GradioRAGInterface:
              return f"Model already loaded: {model_short_name}"

          try:
-             logger.info(f"Reloading model from {self.current_model} to {new_model_path}")
-
-             # Update args
-             self.bot.args.model = new_model_path
+             logger.info(f"Switching model from {self.current_model} to {new_model_path}")

-             # Clear old model from memory
-             if self.bot.model is not None:
-                 del self.bot.model
-                 del self.bot.tokenizer
-                 torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
-             # Load new model
-             self.bot._load_model()
-             self.current_model = new_model_path
-
-             return f"✓ Model loaded: {model_short_name}"
+             if self.use_inference_api:
+                 # For Inference API, just update the model name
+                 self.bot.current_model = new_model_path
+                 self.current_model = new_model_path
+                 return f"✓ Model switched to: {model_short_name} (using Inference API)"
+             else:
+                 # For local model, reload it
+                 # Update args
+                 self.bot.args.model = new_model_path
+
+                 # Clear old model from memory
+                 if hasattr(self.bot, 'model') and self.bot.model is not None:
+                     del self.bot.model
+                     del self.bot.tokenizer
+                     torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+                 # Load new model
+                 self.bot._load_model()
+                 self.current_model = new_model_path
+
+                 return f"✓ Model loaded: {model_short_name}"
          except Exception as e:
              logger.error(f"Error reloading model: {e}", exc_info=True)
              return f"✗ Error loading model: {str(e)}"
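One caveat in the Inference API branch: "switching" models is just a string swap, so a mistyped or un-hosted model id only surfaces at the first generation call. A cheap guard (illustrative, not part of this commit) is to probe the new id with a one-token request before committing the switch:

    def probe_model(client, model_id: str) -> bool:
        # Minimal request; failures surface bad or unavailable model ids early.
        try:
            client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": "ping"}],
                max_tokens=1,
            )
            return True
        except Exception:
            return False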
@@ -394,10 +548,14 @@ SOURCE {i+1} | Similarity: {score:.3f}
      )


- def create_interface(initial_bot: RAGBot) -> gr.Blocks:
+ def create_interface(initial_bot: RAGBot, use_inference_api: Optional[bool] = None) -> gr.Blocks:
      """Create and configure the Gradio interface"""

-     interface = GradioRAGInterface(initial_bot)
+     # Use Inference API on Spaces, local model otherwise
+     if use_inference_api is None:
+         use_inference_api = os.getenv("SPACE_ID") is not None or os.getenv("SYSTEM") == "spaces"
+
+     interface = GradioRAGInterface(initial_bot, use_inference_api=use_inference_api)

      # Get initial model name from bot
      initial_model_short = None
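SPACE_ID (and SYSTEM=spaces) are set in every Space's container, so the auto-detection above needs no extra configuration, and local runs fall through to the local-model path. To exercise the API path off-Spaces, pass the flag explicitly and provide a token, e.g.:

    os.environ.setdefault("HF_TOKEN", "hf_...")  # placeholder token for InferenceClient
    demo = create_interface(bot, use_inference_api=True)
    demo.launch()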
@@ -687,8 +845,20 @@ def create_demo_for_spaces():
          parser.add_argument('--seed', type=int, default=42)

          args = parser.parse_args([])  # Empty args for Spaces
+
+         # Create bot but skip model loading (we'll use Inference API)
+         # We still need the vector database
+         # Set a flag to skip model loading
+         args.skip_model_loading = True
          bot = RAGBot(args)
-         return create_interface(bot)
+
+         # Don't load the model - we'll use Inference API
+         # Just verify vector DB is available
+         if bot.vector_retriever is None:
+             raise Exception("Vector database not available")
+
+         # Use Inference API instead of loading model
+         return create_interface(bot, use_inference_api=True)
      except Exception as e:
          logger.error(f"Error creating demo for Spaces: {e}", exc_info=True)
          # Return a simple error demo
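Note that args.skip_model_loading only has an effect if RAGBot.__init__ consults it, and bot.py is not touched in this commit. The guard it presumes would look roughly like this (a sketch; bot.py internals, including the loader names, are assumptions not shown in this diff):

    class RAGBot:
        def __init__(self, args):
            self.args = args
            self.model = None
            self.tokenizer = None
            self._load_vector_db()  # hypothetical loader name; retrieval is always needed
            if not getattr(args, "skip_model_loading", False):
                self._load_model()  # skipped on Spaces; generation goes through the API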
 