arahrooh committed
Commit 1553f78 · 1 Parent(s): 084bec8

Optimize memory usage: use float16 on CPU and fix double loading

Files changed (2)
  1. app.py +17 -3
  2. bot.py +15 -0
app.py CHANGED

@@ -674,6 +674,8 @@ def create_demo_for_spaces():
     try:
         # Initialize with default args for Spaces
         parser = argparse.ArgumentParser()
+        # Use Llama-3.2-3B as default (will use float16 on CPU to save memory)
+        # For Spaces with limited memory, consider upgrading hardware tier
         parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct')
         parser.add_argument('--vector-db-dir', default='./chroma_db')
         parser.add_argument('--data-dir', default='./Data Resources')
@@ -699,9 +701,21 @@ def create_demo_for_spaces():
         gr.Markdown(f"# Error Initializing Chatbot\n\nAn error occurred: {str(e)}")
         return error_demo
 
-# Create demo at module level for Hugging Face Spaces
-# This is what Spaces will import and use
-demo = create_demo_for_spaces()
+# For Hugging Face Spaces: cache the demo to avoid double initialization
+# get_demo() creates the demo on the first call and reuses it afterwards
+# This prevents loading the model twice
+_demo_cache = None
+
+def get_demo():
+    """Lazy loader for demo - only creates it once"""
+    global _demo_cache
+    if _demo_cache is None:
+        _demo_cache = create_demo_for_spaces()
+    return _demo_cache
+
+# For Hugging Face Spaces: expose demo at module level
+# Spaces will import this and use it
+demo = get_demo()
 
 # For local execution
 if __name__ == "__main__":
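
A note on the app.py change: demo = get_demo() still executes at import time, so Spaces still loads the model when it imports app.py; what the cache prevents is a second create_demo_for_spaces() call when more than one code path asks for the demo (e.g. the module-level import and the local __main__ path). Below is a minimal, self-contained sketch of the same memoization pattern, where build_demo() is a hypothetical stand-in for the expensive create_demo_for_spaces():

    # Caching pattern from app.py in isolation; build_demo() is a stand-in
    # for the expensive create_demo_for_spaces() call.
    _demo_cache = None

    def build_demo():
        print("expensive initialization runs exactly once")
        return object()

    def get_demo():
        """Return the cached demo, creating it on the first call only."""
        global _demo_cache
        if _demo_cache is None:
            _demo_cache = build_demo()
        return _demo_cache

    assert get_demo() is get_demo()  # the second call reuses the cached object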
bot.py CHANGED

@@ -294,6 +294,21 @@ class RAGBot:
         if hf_token:
             model_kwargs["token"] = hf_token
 
+        # Reduce memory usage on CPU: load weights in float16 instead of float32
+        # This roughly halves memory use with minimal quality loss
+        if self.device == "cpu":
+            try:
+                from transformers import BitsAndBytesConfig
+                # bitsandbytes 8-bit quantization is not supported on CPU
+                model_kwargs["load_in_8bit"] = False  # keep 8-bit disabled on CPU
+                # Instead, use float16 even on CPU to save memory
+                model_kwargs["torch_dtype"] = torch.float16
+                logger.info("Using float16 on CPU to reduce memory usage")
+            except ImportError:
+                # Fallback: use float16 anyway
+                model_kwargs["torch_dtype"] = torch.float16
+                logger.info("Using float16 on CPU to reduce memory usage (fallback)")
+
         # For MPS, use device_map; for CUDA, let it auto-detect
         if self.device == "mps":
             model_kwargs["device_map"] = self.device
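
A note on the bot.py change: BitsAndBytesConfig is imported but never used, and load_in_8bit = False is a no-op, so the CPU branch effectively reduces to a single torch_dtype override. Below is a minimal sketch of that effective logic as a standalone helper; apply_cpu_memory_savings is a hypothetical name, not part of bot.py. If some CPU ops lack float16 kernels, torch.bfloat16 is usually the safer half-precision dtype on CPU:

    import logging
    import torch

    logger = logging.getLogger(__name__)

    def apply_cpu_memory_savings(model_kwargs: dict, device: str) -> dict:
        """Hypothetical helper mirroring the bot.py change: on CPU, request
        half-precision weights (~50% of the float32 footprint)."""
        if device == "cpu":
            # torch.bfloat16 is often more robust than float16 on CPU
            model_kwargs["torch_dtype"] = torch.float16
            logger.info("Using float16 on CPU to reduce memory usage")
        return model_kwargs

The returned kwargs are then meant to be passed to AutoModelForCausalLM.from_pretrained(), which accepts a torch_dtype argument.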