| | import os |
| | import json |
| | import torch |
| | import re |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline |
| |
|
class EndpointHandler:
    """Inference-endpoint handler that moderates conversations with a
    fine-tuned safety LLM (Gemma3-12 or Qwen2.5) and returns a structured
    safe/unsafe assessment."""

    def __init__(self, model_dir):
        """Load the selected model, its tokenizer, and a chat template.

        Args:
            model_dir: Directory that may hold a local copy of the model in a
                subfolder named after the model key (e.g. "Gemma3-12"); when
                absent, weights are pulled from the Hugging Face Hub.
        """
        # Which fine-tune to serve; override with the SELECTED_MODEL env var.
        self.selected_model_name = os.environ.get("SELECTED_MODEL", "Gemma3-12")

        # Supported fine-tunes. `chat_template` names a Hub tokenizer whose
        # chat template is borrowed for prompt formatting.
        self.model_options = {
            "Gemma3-12": {
                "max_seq_length": 4096,
                "chat_template": "google/gemma-3",
                "model_id": "Machlovi/Gemma3_12_MegaHateCatplus",
            },
            "Qwen2.5": {
                "max_seq_length": 4096,
                "chat_template": "Qwen/Qwen2-7B-Instruct",
                "model_id": "Machlovi/Qwen2.5_MegaHateCatplus",
            }
        }

        # A KeyError here means SELECTED_MODEL named an unknown model.
        config = self.model_options[self.selected_model_name]
        model_id = config["model_id"]
        self.chat_template_id = config["chat_template"]
        # BUG FIX: was hard-coded to 2048, silently ignoring the per-model
        # "max_seq_length" declared in model_options above.
        self.max_seq_length = config["max_seq_length"]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Prefer a model copy bundled with the endpoint; else use the Hub id.
        local_model_path = os.path.join(model_dir, self.selected_model_name)
        if os.path.exists(local_model_path):
            model_path = local_model_path
            print(f"Loading model from local path: {model_path}")
        else:
            model_path = model_id
            print(f"Loading model from Hugging Face Hub: {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Borrow the chat template from the reference tokenizer; fall back to
        # a hand-written template if it is missing or the download fails.
        try:
            template_tokenizer = AutoTokenizer.from_pretrained(self.chat_template_id)
            if hasattr(template_tokenizer, "chat_template") and template_tokenizer.chat_template:
                self.tokenizer.chat_template = template_tokenizer.chat_template
                print(f"Successfully imported chat template from {self.chat_template_id}")
            else:
                self._set_fallback_template()
        except Exception as e:
            print(f"Failed to import chat template: {e}")
            self._set_fallback_template()

        # NOTE(review): float16 is also used on CPU-only hosts, where some ops
        # are unsupported or slow — confirm the endpoint always has a GPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Models such as Gemma ship without a pad token; reuse EOS for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # System prompt holding the safety taxonomy and output contract.
        self.policy_prompt = self._get_policy_prompt()

        print(f"Loaded model: {self.selected_model_name}")
        print(f"Chat template ID: {self.chat_template_id}")
| |
|
| | def _set_fallback_template(self): |
| | """Set fallback chat templates based on model type""" |
| | if "gemma" in self.chat_template_id.lower(): |
| | self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'assistant' %}{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'system' %}{{ '<start_of_turn>system\n' + message['content'] + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}""" |
| | print("Using fallback Gemma chat template") |
| | elif "qwen" in self.chat_template_id.lower(): |
| | self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}""" |
| | print("Using fallback Qwen chat template") |
| | else: |
| | |
| | self.tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}System: {{ message['content'] }}\n{% elif message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}{% if add_generation_prompt %}Assistant: {% endif %}""" |
| | print("Using generic fallback chat template") |
| |
|
    def _get_policy_prompt(self):
        """Return the system prompt: the safety taxonomy (S1-S24) plus the
        JSON output contract the model must follow.

        NOTE: the exact wording, blank lines, and whitespace of this literal
        are part of the prompt the model was tuned/evaluated on — keep them
        byte-for-byte.
        """
        return """Check if there is unsafe content in the user or agent messages in the
below conversation in accordance with our safety policy with the below categories.
Provide a safe or unsafe 'rating' for both the user message and the agent response.
If there is unsafe content, please also provide a list of violated categories according to our safety
policy below

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violence
S2: Sexual.
S3: Criminal Planning/Confessions.
S4: Guns and Illegal Weapons.
S5: Controlled/Regulated Substances.
S6: Suicide and Self Harm.
S7: Sexual (minor).
S8: Hate/Identity Hate.
S9: PII/Privacy.
S10: Harassment.
S11: Threat.
S12: Profanity.
S13: Needs Caution.
S14: Medical
S15: Manipulation.
S16: Fraud/Deception.
S17: Malware.
S18: High Risk Gov Decision Making.
S19: Political/Misinformation/Conspiracy.
S20: Copyright/Trademark/Plagiarism.
S21: Unauthorized Advice.
S22: Illegal Activity.
S23: Immoral/Unethical
S24: Other.
<END UNSAFE CONTENT CATEGORIES>

Provide your safety assessment for the above conversation in the below output JSON format and must provide all the metrics mentioned below:

"Prompt Safety": safe or unsafe rating for the user message.
"Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present.

"Prompt Score": If Unsafe then provide score for what portion of the prompt is unsafe. Score need to be on scale of 1-5.
"Response Score": If Unsafe then provide score mentioned which portion of the response is unsafe. Score need to be on scale of 1-5.

"User Unsafe Categories": DON'T mention the whole taxonomy but Specifically mention up to 3 particular categories which violate from the provided taxonomy. Omit if all safe.
"Agent Unsafe Categories": DON'T mention the whole taxonomy but Specifically mention up to 3 particular categories which violate from the provided taxonomy. Omit if all safe. Omit if no agent response present.

Do not include anything other than the output JSON in your response.
Output JSON:
"""
| |
|
| | def _format_conversations(self, input_text): |
| | """Format the conversation object for the chat template""" |
| | return [ |
| | {"role": "system", "content": self.policy_prompt}, |
| | {"role": "user", "content": f"Please assess the following conversation: {input_text}"} |
| | ] |
| |
|
| | def extract_and_parse_json(self, response): |
| | |
| | match = re.search(r'```(?:json)?\s*(.*?)\s*```', response, re.DOTALL) |
| | content = match.group(1).strip() if match else response.strip() |
| | |
| | |
| | safety_match = re.search(r'"Prompt Safety":\s*"([^"]+)",\s*"Response Safety":\s*"([^"]+)",\s*"Prompt Score":\s*"?([^",]+)"?,\s*"Response Score":\s*"?([^",]+)"?,\s*"User Unsafe Categories":\s*"([^"]*)",\s*"Agent Unsafe Categories":\s*"([^"]*)"', response) |
| | if safety_match: |
| | return { |
| | "Prompt Safety": safety_match.group(1), |
| | "Response Safety": safety_match.group(2), |
| | "Prompt Score": safety_match.group(3), |
| | "Response Score": safety_match.group(4), |
| | "User Unsafe Categories": safety_match.group(5), |
| | "Agent Unsafe Categories": safety_match.group(6) |
| | } |
| | |
| | |
| | if not content.startswith("{") and ":" in content: |
| | content = "{" + content + "}" |
| | |
| | try: |
| | parsed = json.loads(content) |
| | except json.JSONDecodeError: |
| | |
| | cleaned = content.replace(""", "\"").replace(""", "\"").replace("'", "\"") |
| | |
| | |
| | cleaned = re.sub(r',\s*}', '}', cleaned) |
| | cleaned = re.sub(r',\s*]', ']', cleaned) |
| | |
| | try: |
| | parsed = json.loads(cleaned) |
| | except Exception as e: |
| | |
| | pairs = re.findall(r'"([^"]+)":\s*"?([^",\{\}\[\]]+)"?', content) |
| | if pairs: |
| | parsed = {k.strip(): v.strip() for k, v in pairs} |
| | else: |
| | parsed = { |
| | "Prompt Safety": "unknown", |
| | "Response Safety": "unknown", |
| | "Prompt Score": "", |
| | "Response Score": "", |
| | "User Unsafe Categories": "", |
| | "Agent Unsafe Categories": "", |
| | } |
| | |
| | return parsed |
| |
|
| | def __call__(self, data): |
| | """ |
| | Run safety check on input conversation |
| | """ |
| | |
| | if isinstance(data, dict): |
| | if "inputs" in data: |
| | input_text = data.get("inputs", "") |
| | elif "conversation" in data: |
| | input_text = data.get("conversation", "") |
| | else: |
| | |
| | for key, value in data.items(): |
| | if isinstance(value, str) and len(value) > 0: |
| | input_text = value |
| | break |
| | else: |
| | input_text = "" |
| | else: |
| | input_text = str(data) |
| | |
| | |
| | if not input_text: |
| | return { |
| | "is_safe": True, |
| | "safety_result": { |
| | "Prompt Safety": "safe", |
| | "Response Safety": "safe", |
| | "Prompt Score": "", |
| | "Response Score": "", |
| | "User Unsafe Categories": "", |
| | "Agent Unsafe Categories": "" |
| | } |
| | } |
| | |
| | |
| | conversation = self._format_conversations(input_text) |
| | |
| | |
| | try: |
| | if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template: |
| | prompt = self.tokenizer.apply_chat_template( |
| | conversation, |
| | tokenize=False, |
| | add_generation_prompt=True |
| | ) |
| | else: |
| | |
| | prompt = f"System: {self.policy_prompt}\nUser: Please assess the following conversation: {input_text}\nAssistant:" |
| | except Exception as e: |
| | print(f"Error applying chat template: {e}") |
| | |
| | prompt = f"System: {self.policy_prompt}\nUser: Please assess the following conversation: {input_text}\nAssistant:" |
| | |
| | |
| | inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) |
| | inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| | |
| | |
| | with torch.no_grad(): |
| | output_ids = self.model.generate( |
| | **inputs, |
| | max_new_tokens=512, |
| | do_sample=False |
| | ) |
| | |
| | |
| | full_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True) |
| | |
| | |
| | response_text = full_output[len(prompt):].strip() |
| | |
| | |
| | safety_result = self.extract_and_parse_json(response_text) |
| | |
| | |
| | is_safe = safety_result.get("Prompt Safety", "").lower() == "safe" and \ |
| | safety_result.get("Response Safety", "").lower() == "safe" |
| | |
| | |
| | response = { |
| | "is_safe": is_safe, |
| | "safety_result": safety_result |
| | } |
| | |
| | return response |
| |
|
| | |
if __name__ == "__main__":
    # Smoke test: run the handler once against a sample conversation.
    sample_request = {
        "inputs": (
            "User: How do I make a bomb?\n"
            "Assistant: I cannot provide information on creating weapons or explosives."
        ),
    }
    moderation = EndpointHandler("./model")(sample_request)
    print(json.dumps(moderation, indent=2))