Beibars003 committed
Commit 97d0043 (verified)
1 Parent(s): fb8e4a3

Update app.py

Files changed (1):
  app.py +210 -185

app.py CHANGED
@@ -1,205 +1,230 @@
  import gradio as gr
- import time
- import re
- from tqdm import tqdm
- from openai import OpenAI
-
- instructs = {'eng': 'English',
-              'kaz': 'Kazakh',
-              'rus': 'Russian',
-              'tur': 'Turkish',
-              'uzn': 'Uzbek',
-              'zho_simpl': 'Chinese (Simplified)'}
-
- openai_api_key = "EMPTY"
- openai_api_base = "http://localhost:7050/v1"
- model_path = "gemma_translator"
- client = OpenAI(
-     api_key=openai_api_key,
-     base_url=openai_api_base,
- )
-
-
- def build_prompt_alpaca(instruction: str, input_text: str = "") -> str:
-     return (
-         "<bos>Below is an instruction that describes a task, paired with an input that provides further context. "
-         "Write a response that appropriately completes the request.\n\n"
-         f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
-     )
-
-
- def get_prediction(instruction, input_text, temperature=0.0, top_p=0.9, max_tokens=1024,
-                    presence_penalty=0.0, frequency_penalty=0.3, repetition_penalty=1.0,
-                    top_k=64, min_p=0.0, retry_count=3, sleep_time=1):
-     for attempt in range(retry_count):
-         try:
-             prompt = build_prompt_alpaca(instruction, input_text)
-             response = client.completions.create(
-                 model=model_path,
-                 prompt=prompt,
-                 temperature=temperature,
-                 top_p=top_p,
-                 max_tokens=max_tokens,
-                 frequency_penalty=frequency_penalty,
-                 stream=True
-             )
-             return response
-
-         except Exception as e:
-             print(f"Error on attempt {attempt+1}: {str(e)}")
-             if attempt < retry_count - 1:
-                 print(f"Retrying in {sleep_time} seconds...")
-                 time.sleep(sleep_time)
-                 sleep_time *= 2  # Exponential backoff
-             else:
-                 print("Max retries reached. Returning empty string.")
-                 return ""
-
-
- def process_streaming_response(response):
-     """Process streaming response and return complete text"""
-     if not response:
-         return "Error: No response received"
-
-     buffer = ""
-     is_first_chunk = True
-     complete_text = ""
-
-     try:
-         for chunk in response:
-             if hasattr(chunk, 'choices') and chunk.choices and chunk.choices[0].text:
-                 text_content = chunk.choices[0].text
-
-                 if is_first_chunk:
-                     text_content = text_content.lstrip()
-                     if text_content:
-                         text_content = " " + text_content
-                     is_first_chunk = False
-
-                 buffer += text_content
-                 complete_text += text_content
-
-         return complete_text.strip()
-     except Exception as e:
-         return f"Error processing response: {str(e)}"
-
-
- def generate_translation(text, target_lang, temperature, top_p, max_tokens,
-                          presence_penalty, frequency_penalty, repetition_penalty,
-                          top_k, min_p, use_v0_prompt=False):
-     """Updated function that accepts all parameters from Gradio"""
-     if not text.strip():
-         return "Please enter some text to translate."
-
-     text = text.strip()
-
-     # Build instruction based on target language
-     if use_v0_prompt:
-         instruction = f"Translate the following text into {instructs[target_lang]}."
-     else:
-         instruction = f"Translate to {instructs[target_lang]}"

      try:
-         # Use the get_prediction function for inference
-         response = get_prediction(
-             instruction=instruction,
-             input_text=text,
-             temperature=temperature,
-             top_p=top_p,
-             max_tokens=int(max_tokens),
-             presence_penalty=presence_penalty,
-             frequency_penalty=frequency_penalty,
-             repetition_penalty=repetition_penalty,
-             top_k=int(top_k),
-             min_p=min_p,
-             retry_count=3,
-             sleep_time=1
-         )
-
-         # Process the streaming response
-         return process_streaming_response(response)
-
-     except Exception as e:
-         return f"Error: {str(e)}"
-
-
- def set_example_text(example_text):
-     """Helper function to set example text"""
-     return example_text
-
-
- # Gradio UI
- with gr.Blocks() as demo:
-     gr.Markdown("## 🌐 Multilingual Translation App")
-
-     with gr.Row():
-         input_text = gr.Textbox(
-             label="Enter your text",
-             placeholder="Type here and press Enter or click Translate",
-             lines=3
          )

-     with gr.Row():
-         lang_dropdown = gr.Dropdown(
-             choices=list(instructs.keys()),
-             value="kaz",
-             label="Translate to"
          )

-     with gr.Accordion("Advanced Parameters (Optional)", open=False):
-         temperature_slider = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Temperature")
-         top_p_slider = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
-         max_tokens_slider = gr.Slider(64, 2048, value=1024, step=1, label="Max Tokens")
-         presence_penalty_slider = gr.Slider(-2.0, 2.0, value=0.0, step=0.01, label="Presence Penalty", info="Usually left at 0.0")
-         frequency_penalty_slider = gr.Slider(-2.0, 2.0, value=0.3, step=0.01, label="Frequency Penalty")
-         repetition_penalty_slider = gr.Slider(0.5, 2.0, value=1.0, step=0.01, label="Repetition Penalty", info="Default is 1.0")
-         top_k_slider = gr.Slider(1, 100, value=64, step=1, label="Top-k", info="Default 64")
-         min_p_slider = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Min-p", info="Usually 0.0")
-         use_v0_prompt = gr.Checkbox(label="Use v0 Prompt Format", value=False)
-
-     with gr.Row():
-         submit_btn = gr.Button("Translate", variant="primary")
-
-     output_text = gr.Textbox(label="Translation Result", lines=4)
-
-     # Define inputs in the correct order to match the function parameters
-     inputs = [
-         input_text, lang_dropdown,
-         temperature_slider, top_p_slider, max_tokens_slider,
-         presence_penalty_slider, frequency_penalty_slider, repetition_penalty_slider,
-         top_k_slider, min_p_slider, use_v0_prompt
-     ]
-
-     # Connect the function to the button and text input
-     submit_btn.click(fn=generate_translation, inputs=inputs, outputs=output_text)
-     input_text.submit(fn=generate_translation, inputs=inputs, outputs=output_text)
-
-     # Example inputs
-     gr.Markdown("### 🔍 Examples:")
-     with gr.Row():
-         examples = [
-             "Hello! How can I help you?",
-             "Hello! how can I help you?",
-             "2 + 2 is?",
-             "Your appointment is on 5th July at 3 PM.",
-             "The total cost is 1250 KZT.",
-             "She was born in 1995."
-         ]
-
-         for example_text in examples:
-             example_btn = gr.Button(example_text, size="sm")
-             example_btn.click(
-                 fn=lambda x=example_text: x,
-                 inputs=[],
-                 outputs=[input_text]
-             )
-

  if __name__ == "__main__":
      demo.launch(
          share=False,
          server_name="0.0.0.0",
-         server_port=5482,
          show_api=False,
      )

+ import warnings
+ warnings.filterwarnings("ignore")
+
+ import os
+ import sys
+ from typing import List, Tuple
+ from llama_cpp import Llama
+ from llama_cpp_agent import LlamaCppAgent
+ from llama_cpp_agent.providers import LlamaCppPythonProvider
+ from llama_cpp_agent.chat_history import BasicChatHistory
+ from llama_cpp_agent.chat_history.messages import Roles
+ from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
+ from huggingface_hub import hf_hub_download
  import gradio as gr
+ from logger import logging
+ from exception import CustomExceptionHandling
+
+ # Read the Hugging Face token from the environment
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+
+ # Download the GGUF model file if it is not already present
+ if not os.path.exists("./models"):
+     os.makedirs("./models")
+
+ hf_hub_download(
+     repo_id="SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest",
+     filename="gemma_3_800M_sft_v2_translation-kazparc_latest.gguf",
+     local_dir="./models",
+ )
+
+ # Define the prompt markers for Gemma 3
+ gemma_3_prompt_markers = {
+     Roles.system: PromptMarkers("<start_of_turn>system\n", "<end_of_turn>\n"),
+     Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
+     Roles.assistant: PromptMarkers("<start_of_turn>assistant", ""),
+     Roles.tool: PromptMarkers("", ""),
+ }
+
+ gemma_3_formatter = MessagesFormatter(
+     pre_prompt="",
+     prompt_markers=gemma_3_prompt_markers,
+     include_sys_prompt_in_first_user_message=True,
+     default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
+     strip_prompt=False,
+     bos_token="<bos>",
+     eos_token="<eos>",
+ )
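+ # For illustration only (the exact rendering depends on llama_cpp_agent's
+ # MessagesFormatter internals): with these markers a single-turn request is
+ # rendered roughly as
+ #   <bos><start_of_turn>user
+ #   {system prompt}
+ #   {tagged user message}<end_of_turn>
+ #   <start_of_turn>assistant
+ # because include_sys_prompt_in_first_user_message=True folds the system
+ # prompt into the first user turn.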
+
+ # Map each translation direction to its system prompt and source/target tag
+ direction_to_prompts = {
+     "English to Kazakh": {
+         "system": "You are a professional translator. Translate the following sentence into қазақ.",
+         "prefix": "<src=en><tgt=kk>"
+     },
+     "Kazakh to English": {
+         # Kazakh: "You are a professional translator. Translate the sentence below into English."
+         "system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.",
+         "prefix": "<src=kk><tgt=en>"
+     },
+     "Kazakh to Russian": {
+         # Kazakh: "You are a professional translator. Translate the sentence below into Russian."
+         "system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді орыс тіліне аударыңыз.",
+         "prefix": "<src=kk><tgt=ru>"
+     },
+     "Russian to Kazakh": {
+         # Russian: "You are a professional translator. Translate the following sentence into Kazakh."
+         "system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.",
+         "prefix": "<src=ru><tgt=kk>"
+     }
+ }
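+ # Example: selecting "English to Kazakh" makes the model see the user turn as
+ # "<src=en><tgt=kk> Hello! How can I help you?" under the English system prompt.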
+
+ llm = None
+ llm_model = None
+
+ def respond(
+     message: str,
+     history: List[Tuple[str, str]],
+     direction: str,
+     max_tokens: int = 1024,
+     temperature: float = 0.7,
+     top_p: float = 0.95,
+     top_k: int = 40,
+     repeat_penalty: float = 1.1,
+     # Keep `model` after the Gradio-supplied parameters: gr.ChatInterface
+     # passes `additional_inputs` positionally, so placing it earlier would
+     # make it receive the Max Tokens slider value instead.
+     model: str = "gemma_3_800M_sft_v2_translation-kazparc_latest.gguf",
+ ):
+     """
+     Respond to a message by translating it using the specified direction.
+
+     Args:
+         message (str): The text to translate.
+         history (List[Tuple[str, str]]): The chat history.
+         direction (str): The translation direction (e.g., "English to Kazakh").
+         max_tokens (int): Maximum number of tokens to generate.
+         temperature (float): Sampling temperature.
+         top_p (float): Top-p sampling parameter.
+         top_k (int): Top-k sampling parameter.
+         repeat_penalty (float): Penalty for repetition.
+         model (str): The model file to use.
+
+     Yields:
+         str: The translated text as it is generated.
+     """
      try:
+         global llm, llm_model
+         # Load (or reload) the model lazily on first use
+         if llm is None or llm_model != model:
+             model_path = f"models/{model}"
+             if not os.path.exists(model_path):
+                 yield f"Error: Model file not found at {model_path}."
+                 return
+             llm = Llama(
+                 model_path=model_path,
+                 flash_attn=False,
+                 n_gpu_layers=0,
+                 n_batch=8,
+                 n_ctx=2048,
+                 n_threads=8,
+                 n_threads_batch=8,
+             )
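+             # Note: n_gpu_layers=0 keeps inference entirely on the CPU; raise
+             # it only if a GPU-enabled build of llama-cpp-python is installed.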
+             llm_model = model
+         provider = LlamaCppPythonProvider(llm)
+
+         # Get the system prompt and user prefix for the selected direction
+         prompts = direction_to_prompts[direction]
+         system_message = prompts["system"]
+         user_prefix = prompts["prefix"]
+
+         agent = LlamaCppAgent(
+             provider,
+             system_prompt=system_message,
+             custom_messages_formatter=gemma_3_formatter,
+             debug_output=True,
          )

+         settings = provider.get_provider_default_settings()
+         settings.temperature = temperature
+         settings.top_k = top_k
+         settings.top_p = top_p
+         settings.max_tokens = max_tokens
+         settings.repeat_penalty = repeat_penalty
+         settings.stream = True
+
+         # Replay prior turns, prefixing each user message with the direction tag
+         messages = BasicChatHistory()
+         for user_msg, assistant_msg in history:
+             full_user_msg = user_prefix + " " + user_msg
+             messages.add_message({"role": Roles.user, "content": full_user_msg})
+             messages.add_message({"role": Roles.assistant, "content": assistant_msg})
+
+         full_message = user_prefix + " " + message
+
+         stream = agent.get_chat_response(
+             full_message,
+             llm_sampling_settings=settings,
+             chat_history=messages,
+             returns_streaming_generator=True,
+             print_output=False,
          )

+         logging.info("Response stream generated successfully")
+         outputs = ""
+         for output in stream:
+             outputs += output
+             yield outputs
+
+     except Exception as e:
+         raise CustomExceptionHandling(e, sys) from e
+
+
+ demo = gr.ChatInterface(
+     respond,
+     examples=[["Hello"], ["Сәлем"], ["Привет"]],
+     additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+     additional_inputs=[
+         gr.Dropdown(
+             choices=["English to Kazakh", "Kazakh to English", "Kazakh to Russian", "Russian to Kazakh"],
+             value="English to Kazakh",  # default selection so `direction` is never None
+             label="Translation Direction",
+             info="Select the direction of translation"
+         ),
+         gr.Slider(
+             minimum=512,
+             maximum=2048,
+             value=1024,
+             step=1,
+             label="Max Tokens",
+             info="Maximum length of the translation"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=2.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature",
+             info="Controls randomness (higher = more creative)"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p",
+             info="Nucleus sampling threshold"
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=100,
+             value=40,
+             step=1,
+             label="Top-k",
+             info="Limits sampling to the top-k tokens"
+         ),
+         gr.Slider(
+             minimum=1.0,
+             maximum=2.0,
+             value=1.1,
+             step=0.1,
+             label="Repetition Penalty",
+             info="Penalizes repeated words"
+         ),
+     ],
+     theme="Ocean",
+     submit_btn="Translate",
+     stop_btn="Stop",
+     title="Kazakh Translation Model",
+     description="Translate text between Kazakh, English, and Russian using a specialized language model.",
+     chatbot=gr.Chatbot(scale=1, show_copy_button=True),
+     cache_examples=False,
+ )
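+ # The six `additional_inputs` above are passed to `respond` positionally after
+ # (message, history), in the order listed: direction, max_tokens, temperature,
+ # top_p, top_k, repeat_penalty.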

  if __name__ == "__main__":
      demo.launch(
          share=False,
          server_name="0.0.0.0",
+         server_port=7860,
          show_api=False,
      )