Werli committed on
Commit
2937a35
·
verified ·
1 Parent(s): c85588d

Delete modules/tag_enhancer.py

Browse files
Files changed (1) hide show
  1. modules/tag_enhancer.py +0 -52
modules/tag_enhancer.py DELETED
@@ -1,52 +0,0 @@
1
- import gradio as gr
2
- import re,torch
3
- from transformers import pipeline,AutoTokenizer,AutoModelForSeq2SeqLM
4
-
5
# Inference device. Hard-coded to CPU: this Space runs on HF hardware with no
# GPU. (The previous `"cpu" if torch.cuda.is_available() else "cpu"` ternary
# had two identical branches — the CUDA probe was dead code.)
device = "cpu"
6
-
7
def load_models():
    """Load the three prompt-enhancement models used by this module.

    Returns:
        tuple: ``(enhancer_medium, enhancer_long, enhancer_flux)`` — two
        summarization pipelines and one text2text-generation pipeline.
        If any model fails to load, the error is printed and all three
        slots are ``None`` (best-effort startup; callers must cope).
    """
    medium = long_form = flux = None
    try:
        medium = pipeline(
            "summarization",
            model="gokaygokay/Lamini-Prompt-Enchance",
            device=device,
        )
        long_form = pipeline(
            "summarization",
            model="gokaygokay/Lamini-Prompt-Enchance-Long",
            device=device,
        )
        checkpoint = "gokaygokay/Flux-Prompt-Enhance"
        flux_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        flux_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).eval().to(device=device)
        flux = pipeline(
            'text2text-generation',
            model=flux_model,
            tokenizer=flux_tokenizer,
            repetition_penalty=1.5,
            device=device,
        )
    except Exception as err:
        print(err)
        medium = long_form = flux = None
    return medium, long_form, flux

# Load once at import time; all three are None when loading failed.
enhancer_medium, enhancer_long, enhancer_flux = load_models()
20
-
21
def summarize_prompt(input_prompt, model_choice):
    """Enhance *input_prompt* with the model selected by *model_choice*.

    Args:
        input_prompt: The raw prompt text to enhance.
        model_choice: "Medium", "Flux", or anything else (treated as "Long").

    Returns:
        str: The enhanced prompt text.

    Raises:
        RuntimeError: If the enhancer models failed to load at import time.
            (Previously this crashed with an opaque
            ``TypeError: 'NoneType' object is not callable``.)
    """
    # load_models() sets all three to None on failure — fail loudly up front.
    if None in (enhancer_medium, enhancer_long, enhancer_flux):
        raise RuntimeError(
            "Prompt-enhancer models are unavailable (load_models() failed at startup)."
        )

    if model_choice == "Medium":
        result = enhancer_medium("Enhance the description: " + input_prompt)
        summarized_text = result[0]['summary_text']

        # The Lamini model tends to open with boilerplate like "An image of ...";
        # strip everything up to and including the first "of" and re-capitalize
        # the first retained sentence.
        pattern = r'^.*?of\s+(.*?(?:\.|$))'
        match = re.match(pattern, summarized_text, re.IGNORECASE | re.DOTALL)

        if match:
            remaining_text = summarized_text[match.end():].strip()
            modified_sentence = match.group(1).capitalize()
            summarized_text = modified_sentence + ' ' + remaining_text
    elif model_choice == "Flux":
        result = enhancer_flux("Enhance prompt: " + input_prompt, max_length=256)
        summarized_text = result[0]['generated_text']
    else:  # Long
        result = enhancer_long("Enhance the description: " + input_prompt)
        summarized_text = result[0]['summary_text']

    return summarized_text
41
-
42
def prompt_summarizer(character: str, series: str, general: str, model_choice: str):
    """Combine tag fields into one prompt, enhance it, and return UI updates.

    Args:
        character: Comma-separated character tags (may be empty).
        series: Comma-separated series tags (may be empty).
        general: Comma-separated general tags (may be empty).
        model_choice: Enhancer selection forwarded to summarize_prompt().

    Returns:
        tuple: (final prompt string, gr.update enabling a control,
        gr.update enabling a control) — presumably re-enabling buttons
        after generation; confirm against the Gradio wiring.
    """
    characters = character.split(",") if character else []
    serieses = series.split(",") if series else []
    generals = general.split(",") if general else []
    # Drop whitespace-only fragments so trailing/duplicate commas in the
    # input fields don't inject empty tags into the joined prompt.
    tags = [tag for tag in characters + serieses + generals if tag.strip()]
    cprompt = ",".join(tags)

    output = summarize_prompt(cprompt, model_choice)
    # Avoid a dangling leading ", " when no tags were supplied at all.
    prompt = cprompt + ", " + output if cprompt else output

    return prompt, gr.update(interactive=True), gr.update(interactive=True)