from cog import BasePredictor, Input
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        model_name = "defog/sqlcoder-34b-alpha"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
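        # Load weights in half precision and let device_map="auto" spread them
        # across available devices; layers that don't fit are offloaded to ./.cache.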
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            use_cache=True,
            offload_folder="./.cache",
        )

    def predict(
        self,
        prompt: str = Input(description="Prompt to generate from"),
    ) -> str:
        """Run a single prediction on the model"""
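        # Beam search (5 beams) with sampling disabled keeps the generated SQL
        # deterministic; max_length caps prompt plus generated tokens at 300.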
        eos_token_id = self.tokenizer.eos_token_id
        pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=300,
            do_sample=False,
            num_beams=5,
        )
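        # Keep only the SQL between the ```sql fences, trim everything after the
        # first semicolon, and append the semicolon back.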
        generated_query = (
            pipe(
                prompt,
                num_return_sequences=1,
                eos_token_id=eos_token_id,
                pad_token_id=eos_token_id,
            )[0]["generated_text"]
            .split("```sql")[-1]
            .split("```")[0]
            .split(";")[0]
            .strip()
            + ";"
        )
        return generated_query
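

# Example invocation, assuming this file is registered as the predictor in
# cog.yaml (e.g. predict: "predict.py:Predictor"):
#   cog predict -i prompt="..."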