# inspiration from -> https://huggingface.co/spaces/sitammeur/Gemma-llamacpp
import os
import sys
from typing import List, Tuple

from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
from huggingface_hub import hf_hub_download
import gradio as gr
# Read the Hugging Face token from the environment (used when downloading gated/private models)
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Download the GGUF model file if it is not already present
if not os.path.exists("./models"):
    os.makedirs("./models")

hf_hub_download(
    repo_id="SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest",
    filename="gemma_3_800M_sft_v2_translation-kazparc_latest.gguf",
    local_dir="./models",
    token=huggingface_token,
)
# Define the prompt markers for Gemma 3
gemma_3_prompt_markers = {
    Roles.system: PromptMarkers("<start_of_turn>system\n", "<end_of_turn>\n"),
    Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
    Roles.assistant: PromptMarkers("<start_of_turn>assistant", ""),
    Roles.tool: PromptMarkers("", ""),
}

gemma_3_formatter = MessagesFormatter(
    pre_prompt="",
    prompt_markers=gemma_3_prompt_markers,
    include_sys_prompt_in_first_user_message=True,
    default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
    strip_prompt=False,
    bos_token="<bos>",
    eos_token="<eos>",
)
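# Illustrative note (exact whitespace is up to llama_cpp_agent's MessagesFormatter):
# with include_sys_prompt_in_first_user_message=True the system prompt is folded into
# the first user turn, so one exchange is rendered roughly as
#   <start_of_turn>user
#   {system prompt}
#   {user message}<end_of_turn>
#   <start_of_turn>assistant
# and generation stops at "<end_of_turn>" or "<start_of_turn>".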
# Translation direction to prompts mapping
direction_to_prompts = {
    "English to Kazakh": {
        "system": "You are a professional translator. Translate the following sentence into қазақ.",
        "prefix": "<src=en><tgt=kk>",
    },
    "Kazakh to English": {
        "system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.",
        "prefix": "<src=kk><tgt=en>",
    },
    "Kazakh to Russian": {
        "system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді орыс тіліне аударыңыз.",
        "prefix": "<src=kk><tgt=ru>",
    },
    "Russian to Kazakh": {
        "system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.",
        "prefix": "<src=ru><tgt=kk>",
    },
}
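# Note: respond() below prepends the direction prefix to every user message, so for
# "English to Kazakh" the model sees e.g. "<src=en><tgt=kk> Hello" as the user turn.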
llm = None
llm_model = None


def respond(
    message: str,
    history: List[Tuple[str, str]],
    model: str = "gemma_3_800M_sft_v2_translation-kazparc_latest.gguf",
    direction: str = "English to Kazakh",
    max_tokens: int = 64,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 40,
    repeat_penalty: float = 1.1,
):
    """
    Respond to a message by translating it using the specified direction.

    Args:
        message (str): The text to translate.
        history (List[Tuple[str, str]]): The chat history.
        model (str): The model file to use.
        direction (str): The translation direction (e.g., "English to Kazakh").
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        top_p (float): Top-p sampling parameter.
        top_k (int): Top-k sampling parameter.
        repeat_penalty (float): Penalty for repetition.

    Yields:
        str: The translated text as it is generated.
    """
    global llm, llm_model

    # Load the model lazily and cache it; reload only when a different file is selected
    if llm is None or llm_model != model:
        model_path = f"models/{model}"
        if not os.path.exists(model_path):
            yield f"Error: Model file not found at {model_path}."
            return
        llm = Llama(
            model_path=model_path,
            flash_attn=False,
            n_gpu_layers=0,
            n_batch=8,
            n_ctx=2048,
            n_threads=8,
            n_threads_batch=8,
        )
        llm_model = model

    provider = LlamaCppPythonProvider(llm)

    # Get system prompt and user prefix based on direction
    prompts = direction_to_prompts[direction]
    system_message = prompts["system"]
    user_prefix = prompts["prefix"]

    agent = LlamaCppAgent(
        provider,
        system_prompt=system_message,
        custom_messages_formatter=gemma_3_formatter,
        debug_output=True,
    )

    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True

    # Rebuild the chat history, applying the direction prefix to past user messages as well
    messages = BasicChatHistory()
    for user_msg, assistant_msg in history:
        full_user_msg = user_prefix + " " + user_msg
        messages.add_message({"role": Roles.user, "content": full_user_msg})
        messages.add_message({"role": Roles.assistant, "content": assistant_msg})

    full_message = user_prefix + " " + message
    stream = agent.get_chat_response(
        full_message,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False,
    )

    # Stream the cumulative output back to the UI
    outputs = ""
    for output in stream:
        outputs += output
        yield outputs
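# Quick local sanity check (illustrative, not part of the Space UI): respond() is a
# generator that yields the cumulative translation, so the last yielded value is the
# full output.
#   final = ""
#   for partial in respond("Hello", history=[], direction="English to Kazakh"):
#       final = partial
#   print(final)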
demo = gr.ChatInterface(
    respond,
    examples=[["Hello"], ["Сәлем"], ["Привет"]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Dropdown(
            choices=[
                "gemma_3_800M_sft_v2_translation-kazparc_latest.gguf",
            ],
            value="gemma_3_800M_sft_v2_translation-kazparc_latest.gguf",
            label="Model",
            info="Select the AI model to use for translation",
        ),
        gr.Dropdown(
            choices=["English to Kazakh", "Kazakh to English", "Kazakh to Russian", "Russian to Kazakh"],
            value="English to Kazakh",  # default so respond() never receives direction=None
            label="Translation Direction",
            info="Select the direction of translation",
        ),
        gr.Slider(
            minimum=512,
            maximum=2048,
            value=1024,
            step=1,
            label="Max Tokens",
            info="Maximum length of the translation",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Controls randomness (higher = more creative)",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling threshold",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=40,
            step=1,
            label="Top-k",
            info="Limits vocabulary to the top K tokens",
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition Penalty",
            info="Penalizes repeated words",
        ),
    ],
    theme="Ocean",
    submit_btn="Translate",
    stop_btn="Stop",
    title="Kazakh Translation Model",
    description="Translate text between Kazakh, English, and Russian using a specialized language model.",
    chatbot=gr.Chatbot(scale=1, show_copy_button=True),
    cache_examples=False,
)
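# Note: gr.ChatInterface passes the additional_inputs values positionally after
# (message, history), so the order of the components above must match respond()'s
# parameter order (model, direction, max_tokens, temperature, top_p, top_k, repeat_penalty).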
if __name__ == "__main__":
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
    )
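# To run locally (dependencies assumed from the imports above):
#   pip install llama-cpp-python llama-cpp-agent huggingface_hub gradio
#   python app.py
# The app then serves on http://0.0.0.0:7860 as configured in demo.launch().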