import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import gradio as gr

# Load the custom model and tokenizer
model_path = 'redael/model_udc'
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)

# GPT-2 tokenizers ship without a padding token; reuse EOS so the padding/truncation below works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Check if CUDA is available and use GPU if possible, enable FP16 precision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if device.type == 'cuda':
    model = model.half()  # Use FP16 precision
def generate_response(prompt, model, tokenizer, max_length=100, num_beams=1, temperature=0.7, top_p=0.9, repetition_penalty=2.0):
    # Prepare the prompt: wrap bare prompts in the chat format, but leave
    # already-formatted conversations (as built by respond below) untouched
    if "Assistant:" not in prompt:
        prompt = f"User: {prompt}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=max_length,  # budget for newly generated tokens, independent of prompt length
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,  # required for temperature/top_p to take effect
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        early_stopping=True
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Post-processing: keep only the assistant's final reply and drop any leaked role markers
    response = response.split("Assistant:")[-1].strip()
    response_lines = response.split('\n')
    clean_response = []
    for line in response_lines:
        if "User:" not in line and "Assistant:" not in line:
            clean_response.append(line)
    response = ' '.join(clean_response)
    return response.strip()
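
# Illustrative example (not part of the original Space): a quick sanity check of
# generate_response from a Python shell, bypassing the Gradio UI.
# >>> generate_response("What can you help me with?", model, tokenizer)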
def respond(message, history):
    # Prepare the prompt from the history and the new message
    # (history arrives as a list of (user, assistant) pairs in Gradio's tuple-style chat format)
    system_message = "You are a friendly chatbot."
    conversation = system_message + "\n"
    for user_message, assistant_response in history:
        conversation += f"User: {user_message}\nAssistant: {assistant_response}\n"
    conversation += f"User: {message}\nAssistant:"
    # Fixed values for the generation parameters
    max_tokens = 100
    temperature = 0.7
    top_p = 0.9
    response = generate_response(conversation, model, tokenizer, max_length=max_tokens, temperature=temperature, top_p=top_p)
    return response
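
# Illustrative example (not part of the original Space): respond can also be called
# directly with a tuple-style history, e.g.
# >>> respond("Tell me more.", [("Hi!", "Hello! How can I help you today?")])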
# Gradio Chat Interface
demo = gr.ChatInterface(
    respond
)

if __name__ == "__main__":
    demo.launch()