import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    GenerationConfig,
    pipeline,
)
import re
import torch
import threading
import time

# --- Model Configuration ---
MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
N_CTX = 2048
MAX_TOKENS = 500
TEMPERATURE = 0.7
TOP_P = 0.9
STOP_SEQUENCES = ["USER:", "\n\n"]
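
# Note: N_CTX and STOP_SEQUENCES are declared for reference only; generation below is bounded
# by max_new_tokens and the tokenizer's EOS token rather than an explicit context window or
# stop strings.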

# --- Safety Configuration ---
print("Loading safety model (unitary/toxic-bert)...")
try:
    # Using the directly imported pipeline function
    safety_classifier = pipeline(
        "text-classification",
        model="unitary/toxic-bert",
        framework="pt",
    )
    print("Safety model loaded successfully.")
except Exception as e:
    print(f"Error loading safety model: {e}")
    exit(1)

TOXICITY_THRESHOLD = 0.9

def is_text_safe(text: str) -> tuple[bool, str | None]:
    """Return (is_safe, detected_label); empty or whitespace-only text counts as safe."""
    if not text.strip():
        return True, None
    try:
        results = safety_classifier(text)
        if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
            print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
            return False, results[0]['label']
        return True, None
    except Exception as e:
        print(f"Error during safety check: {e}")
        # Fail closed: if the safety check itself errors, treat the text as unsafe.
        return False, "safety_check_failed"
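
# Illustrative classifier output (scores vary per input), e.g.:
#   [{'label': 'toxic', 'score': 0.98}]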

# --- Main Model Loading (using Transformers) ---
print(f"Loading tokenizer for {MODEL_REPO_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
    print("Tokenizer loaded.")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Make sure the model ID is correct and, if it's a private repo, that the HF_TOKEN secret is set in your Space.")
    exit(1)

print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Ensure it's a standard Transformers model and that the HF_TOKEN secret is set if it's private.")
    exit(1)

# Configure generation for streaming
try:
    generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
except Exception:
    # Fall back to the model's bundled config if the repo ships no generation_config.json.
    generation_config = model.generation_config
generation_config.max_new_tokens = MAX_TOKENS
generation_config.temperature = TEMPERATURE
generation_config.top_p = TOP_P
generation_config.do_sample = True
generation_config.eos_token_id = tokenizer.eos_token_id
# Resolve a usable pad token: prefer the tokenizer's pad token, then EOS, then 0.
if tokenizer.pad_token_id is not None:
    generation_config.pad_token_id = tokenizer.pad_token_id
elif tokenizer.eos_token_id is not None:
    generation_config.pad_token_id = tokenizer.eos_token_id
else:
    generation_config.pad_token_id = 0
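
# Note: generate() warns and substitutes eos_token_id when pad_token_id is unset, so resolving
# it explicitly above keeps the logs clean for models that define no pad token.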

# --- Custom Streamer for Gradio and Safety Check ---
class GradioSafetyStreamer(TextIteratorStreamer):
    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
        self.safety_checker_fn = safety_checker_fn
        self.toxicity_threshold = toxicity_threshold
        self.current_sentence_buffer = ""
        self.output_queue = []
        # Capturing group keeps the sentence-ending punctuation so it survives in the output.
        self.sentence_regex = re.compile(r'([.!?]\s*)')
        self.text_done = threading.Event()

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.current_sentence_buffer += text
        # re.split with a capturing group alternates [text, delimiter, text, delimiter, ..., tail].
        parts = self.sentence_regex.split(self.current_sentence_buffer)
        sentences_to_process = ["".join(parts[i:i + 2]) for i in range(0, len(parts) - 1, 2)]
        # The tail has no terminator yet; keep buffering it until more text (or stream_end) arrives.
        self.current_sentence_buffer = parts[-1]

        for sentence in sentences_to_process:
            if not sentence.strip():
                continue
            is_safe, detected_label = self.safety_checker_fn(sentence)
            if not is_safe:
                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
                self.output_queue.append("[Content removed due to safety guidelines]")
                self.output_queue.append("__STOP_GENERATION__")
                return
            self.output_queue.append(sentence)

        if stream_end:
            if self.current_sentence_buffer.strip():
                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
                if not is_safe:
                    self.output_queue.append("[Content removed due to safety guidelines]")
                else:
                    self.output_queue.append(self.current_sentence_buffer)
            self.current_sentence_buffer = ""
            self.text_done.set()

    def __iter__(self):
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    return  # a bare return ends the generator; raising StopIteration here is a RuntimeError (PEP 479)
                yield item
            elif self.text_done.is_set():
                return
            else:
                time.sleep(0.01)
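
# The streamer bridges two threads: model.generate() (the producer, started below) pushes decoded
# text into on_finalized_text(), while the Gradio generator (the consumer) iterates over the
# safety-checked sentences as they arrive.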

# --- Inference Function with Safety and Streaming ---
def generate_word_by_word_with_safety(prompt_text: str):
    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
    # skip_prompt=True keeps the echoed "USER: ... ASSISTANT:" prefix out of the streamed output.
    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD, skip_prompt=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "generation_config": generation_config,
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_TOKENS,
        "eos_token_id": generation_config.eos_token_id,
        "pad_token_id": generation_config.pad_token_id,
    }
    # Run generation in a background thread so this function can yield partial results.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    full_generated_text = ""
    try:
        for new_sentence_or_chunk in streamer:
            full_generated_text += new_sentence_or_chunk
            yield full_generated_text
    except Exception as e:
        print(f"Error during streaming: {e}")
        yield full_generated_text + f"\n\n[Error during streaming: {e}]"
    finally:
        thread.join()

# --- Gradio Blocks Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
        Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.

        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**

        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
        """
    )
    with gr.Row():
        user_prompt = gr.Textbox(
            lines=5,
            label="Enter your prompt here:",
            placeholder="e.g., Explain the concept of quantum entanglement in simple terms.",
            scale=4,
        )
        generated_text = gr.Textbox(label="Generated Text", show_copy_button=True, scale=6)
    send_button = gr.Button("Send", variant="primary")

    send_button.click(
        fn=generate_word_by_word_with_safety,
        inputs=user_prompt,
        outputs=generated_text,
        api_name="predict",
    )

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
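
# Quick client-side sketch (assumes the app is reachable at this URL; adjust host/port as needed):
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   print(client.predict("Explain quantum entanglement in simple terms.", api_name="/predict"))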