| | import json |
| | from threading import Thread |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
| |
|
| | from .phi2_configuration import Phi2Config |
| | from .phi2_model import Phi2ModelForCausalLM |
| |
|
| |
|
# Substring rewrites for parameters whose key starts with "transformer"
# (the backbone). Applied in order, each one to the result of the previous.
_TRANSFORMER_KEY_REWRITES = (
    ("transformer.", "model."),
    (".embd.wte.", ".embedding.embeddings."),
    (".h.", ".parallel_blocks."),
    (".ln.", ".layer_norm."),
    (".mixer.Wqkv.", ".multi_head_attention.Wqkv."),
    (".mixer.out_proj.", ".multi_head_attention.fc_out."),
)

# Substring rewrites for every other parameter (the language-model head).
_HEAD_KEY_REWRITES = (
    ("lm_head.ln.", "lm_head_layer_norm."),
    ("lm_head.linear.", "lm_head_linear."),
)


def remap_phi2_key(key: str) -> str:
    """Rename a ``microsoft/phi-2`` state-dict key for ``Phi2ModelForCausalLM``.

    Keys starting with ``"transformer"`` go through the backbone rewrites;
    all other keys go through the LM-head rewrites. Each rewrite is a plain
    ``str.replace`` (all occurrences, not only a prefix), applied in order.

    >>> remap_phi2_key("transformer.h.0.ln.weight")
    'model.parallel_blocks.0.layer_norm.weight'
    """
    rewrites = (
        _TRANSFORMER_KEY_REWRITES
        if key.startswith("transformer")
        else _HEAD_KEY_REWRITES
    )
    for old, new in rewrites:
        key = key.replace(old, new)
    return key


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/phi-2", trust_remote_code=True
    )
    token_streamer = TextIteratorStreamer(tokenizer)

    device = "cuda"
    # NOTE(review): this path is resolved against the current working
    # directory, not this file's location — run from the parent of
    # simplified_phi2/.
    with open("simplified_phi2/config.json", encoding="utf-8") as config_file:
        model_config = Phi2Config(**json.load(config_file))
    model = Phi2ModelForCausalLM(model_config).to(device)

    # Copy the pretrained weights from the reference implementation into the
    # simplified model, renaming every parameter to the simplified layout.
    phi_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", trust_remote_code=True
    )
    model_state_dict = {
        remap_phi2_key(key): value
        for key, value in phi_model.state_dict().items()
    }
    model.load_state_dict(model_state_dict)
    # load_state_dict copied the values into `model`'s own parameters, so the
    # reference model and the renamed dict (which shares its tensors) are no
    # longer needed; release them before generation.
    del phi_model, model_state_dict
    model.eval()

    prompt_inputs = tokenizer(
        "Here is an essay on sea monkeys: ",
        return_tensors="pt",
        return_attention_mask=False,
    ).to(device)

    # Run generation on a background thread; this thread consumes the
    # streamer and prints text as it arrives.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            prompt_inputs,
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    output_chunks = []
    for new_token in token_streamer:
        output_chunks.append(new_token)
        print(new_token, end="", flush=True)
    print()
    # The streamer is exhausted once generation finishes; wait for the worker
    # so any exception or cleanup in generate() completes before exit.
    thread.join()
    my_output = "".join(output_chunks)
| |
|