Files changed (1)
  1. model.py +203 -0
model.py ADDED
@@ -0,0 +1,203 @@
+ import os
+ import json
+ import torch
+ import logging
+ from typing import Dict, Any, Optional
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ import time
+
+ logger = logging.getLogger(__name__)
+
+
+ class AgriQAAssistant:
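+     """Agricultural question-answering assistant.
+
+     Loads a fine-tuned model from the Hugging Face Hub, falling back to the
+     base model plus a LoRA adapter, and finally to the bare base model if
+     neither load succeeds.
+     """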
+
+     # Single source of truth for the system prompt, shared by the
+     # chat-template and fallback formatting paths below.
+     SYSTEM_PROMPT = "You are AgriQA, an agricultural expert assistant. Your job is to answer farmers' questions with clear, practical, and accurate steps they can directly apply in the field.\n\nWhen answering:\n1. Start with a short, direct answer to the question.\n2. Provide a numbered step-by-step solution.\n3. Include specific details like measurements, quantities, time intervals, and names of products or tools.\n4. Mention any safety precautions if needed.\n5. End with an extra tip or follow-up advice.\n\nFormat Example:\nQuestion: How to control aphid infestation in mustard crops?\nAnswer:\n1. Inspect the crop daily to detect early signs of infestation.\n2. Spray Imidacloprid 17.8% SL at a rate of 0.3 ml per liter of water.\n3. Ensure thorough coverage, especially under the leaves.\n4. Remove surrounding weeds that may host aphids.\n5. Repeat spraying after 7 days if infestation continues.\nNote: Wear gloves and a mask during spraying.\n\nAlways keep your language clear, concise, and easy to understand."
+
+     def __init__(self, model_path: str = "nada013/agriqa-assistant"):
+         self.model_path = model_path
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = None
+         self.tokenizer = None
+         self.config = None
+
+         self.load_model()
+
+     def load_model(self):
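+         """Load the tokenizer and model, preferring the fine-tuned checkpoint on the Hub."""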
+
+         logger.info(f"Loading model from Hugging Face: {self.model_path}")
+
+         try:
+             # Configuration for the uploaded model
+             self.config = {
+                 'base_model': 'Qwen/Qwen1.5-1.8B-Chat',
+                 'generation_config': {
+                     'max_new_tokens': 512,       # Increased for complete responses
+                     'do_sample': True,
+                     'temperature': 0.3,          # Lower temperature for more consistent, structured responses
+                     'top_p': 0.85,               # Slightly lower for more focused sampling
+                     'top_k': 40,                 # Lower for more focused responses
+                     'repetition_penalty': 1.2,   # Higher penalty to avoid repetition
+                     'length_penalty': 1.1,       # Note: only affects beam search; unused with pure sampling
+                     'no_repeat_ngram_size': 3    # Avoid repeating 3-grams
+                 }
+             }
+
+             # Load tokenizer from base model
+             logger.info("Loading tokenizer from base model...")
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.config['base_model'],
+                 trust_remote_code=True
+             )
+
+             if self.tokenizer.pad_token is None:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+
+             # Try to load the model directly from Hugging Face first
+             try:
+                 logger.info("Attempting to load model directly from Hugging Face...")
+                 self.model = AutoModelForCausalLM.from_pretrained(
+                     self.model_path,
+                     torch_dtype=torch.float16,
+                     device_map="auto",
+                     trust_remote_code=True,
+                     attn_implementation="eager"
+                 )
+                 logger.info("Model loaded directly from Hugging Face successfully")
+             except Exception as direct_load_error:
+                 logger.warning(f"Direct loading failed: {direct_load_error}")
+                 logger.info("Falling back to base model + LoRA adapter approach...")
+
+                 # Load base model first
+                 logger.info("Loading base model...")
+                 base_model = AutoModelForCausalLM.from_pretrained(
+                     self.config['base_model'],
+                     torch_dtype=torch.float16,
+                     device_map="auto"
+                 )
+
+                 # Try to load the LoRA adapter
+                 try:
+                     logger.info("Loading LoRA adapter from Hugging Face...")
+                     self.model = PeftModel.from_pretrained(
+                         base_model,
+                         self.model_path,
+                         torch_dtype=torch.float16,
+                         device_map="auto"
+                     )
+                     logger.info("LoRA adapter loaded successfully")
+                 except Exception as lora_error:
+                     logger.warning(f"LoRA adapter loading failed: {lora_error}")
+                     logger.info("Using base model without LoRA adapter...")
+                     self.model = base_model
+
+             # Set to evaluation mode
+             self.model.eval()
+
+             # Log model information
+             logger.info("Model loaded successfully from Hugging Face")
+             logger.info(f"Model type: {type(self.model).__name__}")
+             logger.info(f"Device: {self.device}")
+
+             # Check if it's a PeftModel
+             if hasattr(self.model, 'peft_config'):
+                 logger.info("LoRA adapter configuration:")
+                 for adapter_name, config in self.model.peft_config.items():
+                     logger.info(f"  - {adapter_name}: {config.target_modules}")
+
+         except Exception as e:
+             logger.error(f"Failed to load model: {e}")
+             logger.error(f"Model path: {self.model_path}")
+             logger.error(f"Base model: {self.config['base_model']}")
+             import traceback
+             logger.error(f"Traceback: {traceback.format_exc()}")
+             raise
+
+     def format_prompt(self, question: str) -> str:
+         """Format the question as a chat prompt, preferring the tokenizer's chat template."""
+         # Use the tokenizer's chat template if available
+         if hasattr(self.tokenizer, 'apply_chat_template'):
+             try:
+                 messages = [
+                     {"role": "system", "content": self.SYSTEM_PROMPT},
+                     {"role": "user", "content": question}
+                 ]
+                 formatted_prompt = self.tokenizer.apply_chat_template(
+                     messages,
+                     tokenize=False,
+                     add_generation_prompt=True
+                 )
+                 return formatted_prompt
+             except Exception as e:
+                 logger.warning(f"Failed to use chat template: {e}. Using fallback format.")
+
+         # Fallback: build the ChatML-style prompt Qwen1.5-Chat expects by hand
+         formatted_prompt = f"<|im_start|>system\n{self.SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
+         return formatted_prompt
+
+     def generate_response(self, question: str, max_length: Optional[int] = None) -> Dict[str, Any]:
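+         """Answer a question; returns the answer text plus timing and model metadata."""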
+         start_time = time.time()
+
+         try:
+             # Format the prompt
+             prompt = self.format_prompt(question)
+
+             # Tokenize input
+             inputs = self.tokenizer(
+                 prompt,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=2048
+             ).to(self.device)
+
+             # Generation parameters
+             gen_config = self.config['generation_config'].copy()
+             if max_length is not None:
+                 gen_config['max_new_tokens'] = max_length
+
+             # Generate response
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     **gen_config,
+                     pad_token_id=self.tokenizer.eos_token_id
+                 )
+
+             # Decode only the newly generated tokens, not the prompt
+             response = self.tokenizer.decode(
+                 outputs[0][inputs['input_ids'].shape[1]:],
+                 skip_special_tokens=True
+             ).strip()
+
+             # Calculate response time
+             response_time = time.time() - start_time
+
+             return {
+                 'answer': response,
+                 'response_time': response_time,
+                 'model_info': {
+                     'model_name': 'agriqa-assistant',
+                     'model_source': 'Hugging Face',
+                     'model_path': self.model_path,
+                     'base_model': self.config['base_model']
+                 }
+             }
+
+         except Exception as e:
+             logger.error(f"Error generating response: {e}")
+             return {
+                 'answer': "I apologize, but I encountered an error while processing your question. Please try again.",
+                 'confidence': 0.0,
+                 'response_time': time.time() - start_time,
+                 'error': str(e)
+             }
+
+     def get_model_info(self) -> Dict[str, Any]:
+         """Get information about the loaded model."""
+         return {
+             'model_name': 'agriqa-assistant',
+             'model_source': 'Hugging Face',
+             'model_path': self.model_path,
+             'base_model': self.config['base_model'],
+             'device': self.device,
+             'generation_config': self.config['generation_config']
+         }
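+
+
+ # Minimal usage sketch, assuming torch, transformers, and peft are installed
+ # and the Hub repo above is reachable; the question is only an example.
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     assistant = AgriQAAssistant()
+     result = assistant.generate_response("How do I control aphids in mustard crops?")
+     print(result['answer'])
+     print(f"Answered in {result['response_time']:.2f}s")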