update app
app.py CHANGED

@@ -29,7 +29,16 @@ import numpy as np
 import random
 
 import threading
-generation_lock = threading.Lock()
+# generation_lock = threading.Lock()
+
+# from transformers import StoppingCriteria, StoppingCriteriaList
+# class StopGenerationCriteria(StoppingCriteria):
+# def __init__(self, stop_event):
+# self.stop_event = stop_event
+
+# def __call__(self, input_ids, scores, **kwargs):
+# return self.stop_event.is_set()
+
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
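For reference, the commented-out block added in this hunk sketches a cancellation hook: a custom `StoppingCriteria` that watches a `threading.Event`. A minimal sketch of how such a hook could be wired into `generate()` follows; it assumes the standard `transformers` API and is illustrative only, not part of this commit.

```python
import threading

from transformers import StoppingCriteria, StoppingCriteriaList


class StopGenerationCriteria(StoppingCriteria):
    """Stops generation as soon as the shared event is set."""

    def __init__(self, stop_event: threading.Event):
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Returning True tells generate() to stop at the next decoding step.
        return self.stop_event.is_set()


stop_event = threading.Event()
stopping_criteria = StoppingCriteriaList([StopGenerationCriteria(stop_event)])

# Hypothetical wiring inside generate_kwargs (mirrors the commented-out lines later in the diff):
#   model.generate(**inputs, streamer=streamer, stopping_criteria=stopping_criteria, ...)
# Calling stop_event.set() from another thread would then abort the generation early.
```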
@@ -131,6 +140,11 @@ openers = [
 "Hmmm I think that this message",
 "Reflecting on the message here",
 "Considering what this poster is trying to say",
+"Seeing this message makes me think",
+"Thinking about what this poster is communicating",
+"After reading what's on here",
+"Based on what’s written here",
+"After I look at this whole thing",
 ]
 openers_generic = [
 "Hmmm when thinking about",

@@ -151,27 +165,33 @@ openers_poster_summary = [
 "This poster seems to",
 "My interpretation of the poster is",
 "From what this poster shows, it seems to",
-
-
-
-
-
+"Looking at the poster as a whole, it appears to",
+"Based on the imagery and tone, the poster seems to",
+"Visually, the poster comes across as trying to",
+"To me, this poster is trying to",
+"When I look at this poster, it feels like it aims to",
 "The poster gives me the impression that it intends to",
 ]
 openers_explain = [
-
-
-
-
-
-
-
-
-
-
-
-
-
+"The reason why I think that is because",
+"To explain why I",
+"Well, to explain my thoughts",
+"To put it simply, I feel this way because",
+"My reasoning behind that is",
+"What leads me to that view is",
+"A big part of why I think that is",
+"To give some context for my view,",
+"Here’s why I lean that way:",
+"I see it that way mainly because",
+"Let me explain why I think so",
+"Thinking through it, I realize it's because",
+"To unpack my thinking a bit,",
+"I guess it’s because",
+"The thing that really shapes my view is",
+"It’s pretty much because",
+"A lot of it comes down to",
+"I feel that way mostly because",
+"My thinking comes from the idea that",
 ]
 
 

@@ -184,89 +204,99 @@ def vlm_response(user_input, history, health_topic,
 political, education, income, family_status,
 # extraversion, agreeableness, conscientiousness, neuroticism, openness,
 ):
-
-
-
-
-""" [NOTE] we have not use `history` for this generation """
-# get uploaded image
-image = Image.open(user_input['files'][0]) if user_input['files'] else None
-image_uploaded = True
-if image is None:
-image = Image.new('RGB', (24,24))
-image_uploaded = False
-# image_b64 = convert_to_base64(image)
-print(health_topic)
-# print("Image uploaded:", image_uploaded)
+# # 1. Initialize Stop Event for this session
+# stop_event = threading.Event()
+# # Create the stopping criteria to pass to the model
+# stopping_criteria = StoppingCriteriaList([StopGenerationCriteria(stop_event)])
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# 1. Clear any lingering state
+torch.cuda.empty_cache() # Clear GPU memory
+# 2. Initialize Streamers LOCALLY (Fresh for every request)
+# Note: We need to re-initialize these for every single generation call
+# or just once per function call if we share them.
+streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+# streamer_aux = TextIteratorStreamer(tokenizer_aux, skip_prompt=True, skip_special_tokens=True)
+
+""" [NOTE] we have not use `history` for this generation """
+# get uploaded image
+image = Image.open(user_input['files'][0]) if user_input['files'] else None
+image_uploaded = True
+if image is None:
+image = Image.new('RGB', (24,24))
+image_uploaded = False
+# image_b64 = convert_to_base64(image)
+print(health_topic)
+# print("Image uploaded:", image_uploaded)
+
+
+
+#################################################
+# 1. Construct traits from user inputs
+#################################################
+demo_dict = {
+"Gender": gender,
+"Age": age,
+"Current Profession": profession,
+"Race/Ethnicity": race,
+"Religious/Cultural Group": religion,
+"Political Affiliation": political,
+"Highest Education": education,
+"Annual Household Income": income,
+"Family Status": family_status,
+}
+# big5_dict = {
+# "Extraversion": extraversion,
+# "Agreeableness": agreeableness,
+# "Conscientiousness": conscientiousness,
+# "Neuroticism": neuroticism,
+# "Open-Mindedness": openness,
+# }
+
+demo_info = ""
+for trait, value in demo_dict.items():
+if value != "Leave Blank": # only add non-blank values
+demo_info += f"{trait}: {value}\n"
+else:
+demo_info += f"{trait}: [Not specified]\n"
+persona_score = ""
+persona_score += "Big-Five Trait Scores:\n"
+# for trait, value in big5_dict.items():
+# persona_score += f"{trait}: {value}\n"
+# no locus of control trait score
+locus = None
+
+######################################################################################
+# 1*. modify trait info based on trait selection setings
+# demo_full: wheter include full demographic traits or only selected ones
+# include_big5, include_facet, include_locus: include big5 / facet / locus of control traits or not
+# format: <trait>: <value> if available; else <trait>: [Not specified]
+######################################################################################
+demo_info, persona_score, locus = process_trait_info(
+demo_info, persona_score, locus,
+demo_full=False, include_big5=True,
+include_facet=False, include_locus=False,
+train_mode=False,
+)
+# print(demo_info)
+# print(persona_score)
+
+###############################################
+### Add style variability ###
+###############################################
+style_hint = random.choice(style_variants) # increase style variant
+lexical_hint = random.choice(lexical_flavors) # increase lexical variant
+opening_phrase = random.choice(openers) # increase opening variant
+opening_generic = random.choice(openers_generic) # increase opening variant
+opening_poster = random.choice(openers_poster_summary) # poster summary variation
+opening_explain = random.choice(openers_explain) # thought explanation
+print('Style:', style_hint)
+print('Lexical:', lexical_hint)
+print('Opening:', opening_phrase)
+print('Generic opening:', opening_generic)
 
 
-
+# Wrap the GENERATION logic in try/finally to handle cleanup
+try:
 if image_uploaded:
 """###############################################################
 Case 1: a health poster is uploaded

@@ -275,6 +305,8 @@ def vlm_response(user_input, history, health_topic,
 ################################################
 # * IMAGE UNDERSTANDING
 ################################################
+yield "Analyzing image content..." # UI Feedback
+
 PROMPT = (
 f"Describe the content and main message in given heatlh campaign poster and how it's related to {health_topic}. ",
 "Note that the message could be non-direct or subtle (e.g. irony, fear-driven evoke without explicit texts, etc). Only provide the answer (in 2-4 sentences). ",

@@ -461,8 +493,15 @@ def vlm_response(user_input, history, health_topic,
 {"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
 ]}
 ]
-input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
-inputs = tokenizer_aux(
+# input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
+# inputs = tokenizer_aux(
+# # image.convert("RGB"),
+# input_text,
+# add_special_tokens = False,
+# return_tensors = "pt",
+# ).to(device)
+input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
+inputs = tokenizer(
 # image.convert("RGB"),
 input_text,
 add_special_tokens = False,

@@ -475,7 +514,7 @@ def vlm_response(user_input, history, health_topic,
 # generation with streamer
 generate_kwargs = dict(
 **inputs,
-streamer=streamer_aux,
+streamer=streamer, # streamer_aux,
 max_new_tokens=512,
 use_cache=True,
 # min_p=0.3,

@@ -485,7 +524,7 @@ def vlm_response(user_input, history, health_topic,
 )
 # separate thread to run generation
 thread = threading.Thread(
-target=model_aux.generate,
+target=model.generate, # model_aux.generate,
 kwargs=generate_kwargs
 )
 thread.start()

@@ -494,11 +533,14 @@ def vlm_response(user_input, history, health_topic,
 f"Emulated traits:\n {demo_info}\n" + '='*20 + "\n\n",
 image_desc + "\n\n"
 ]
-for new_token in streamer_aux:
+for new_token in streamer: # streamer_aux:
 outputs.append(new_token)
 final_output = ''.join(outputs)
 yield final_output
 
+# Ensure thread finishes
+thread.join()
+
 # text representation of final response
 response = "".join(outputs[2:]) # ignore trait summary & image description
 print(colored('Traits', 'green'), demo_info)

@@ -534,8 +576,15 @@ def vlm_response(user_input, history, health_topic,
 {"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
 ]}
 ]
-input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
-inputs = tokenizer_aux(
+# input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
+# inputs = tokenizer_aux(
+# image.convert("RGB"),
+# input_text,
+# add_special_tokens = False,
+# return_tensors = "pt",
+# ).to(device)
+input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
+inputs = tokenizer(
 image.convert("RGB"),
 input_text,
 add_special_tokens = False,

@@ -548,7 +597,7 @@ def vlm_response(user_input, history, health_topic,
 # generation with streamer
 generate_kwargs = dict(
 **inputs,
-streamer=streamer_aux,
+streamer=streamer, # streamer_aux,
 max_new_tokens=512,
 use_cache=True,
 min_p=0.85,

@@ -557,17 +606,19 @@ def vlm_response(user_input, history, health_topic,
 )
 # separate thread to run generation
 thread = threading.Thread(
-target=model_aux.generate,
+target=model.generate, # model_aux.generate,
 kwargs=generate_kwargs
 )
 thread.start()
 # stream out generation
 # outputs = [image_desc + "\n\n"]
 outputs += ["\n"]
-for new_token in streamer_aux:
+for new_token in streamer: # streamer_aux:
 outputs.append(new_token)
 final_output = ''.join(outputs)
 yield final_output
+
+thread.join()
 
 
 return answer
@@ -658,6 +709,15 @@ def vlm_response(user_input, history, health_topic,
 outputs.append(new_token)
 final_output = ''.join(outputs)
 yield final_output
+thread.join()
+
+except GeneratorExit:
+print("User disconnected. Waiting for generation to complete...")
+finally:
+# Ensure cleanup happens even on normal finish or errors
+if thread is not None and thread.is_alive():
+thread.join()
+torch.cuda.empty_cache()
 
 """###########################################################################
 Evaluate a given model (specified in model_cfgs)
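Taken together, the hunks above converge on one streaming pattern per request: build the inputs with the main `tokenizer`, launch `model.generate` on a worker thread with a fresh `TextIteratorStreamer`, yield partial text to the UI, then always join the thread and clear the CUDA cache, even if the client disconnects mid-stream. A condensed, self-contained sketch of that pattern is below; names such as `stream_reply`, `prompt`, and `device` are placeholders, and it uses only standard `transformers`/`torch` APIs rather than the app's exact code.

```python
import threading

import torch
from transformers import TextIteratorStreamer


def stream_reply(model, tokenizer, prompt, device="cuda"):
    # Fresh streamer per request, as in the diff (skip prompt echo and special tokens).
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Run generation on a worker thread so this generator can stream partial text.
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512, use_cache=True),
    )
    thread.start()

    outputs = []
    try:
        for new_token in streamer:      # blocks until the worker produces more text
            outputs.append(new_token)
            yield "".join(outputs)      # the UI (e.g. Gradio) renders each partial string
    except GeneratorExit:
        # The client stopped consuming; let the worker finish instead of leaking it.
        print("User disconnected. Waiting for generation to complete...")
    finally:
        if thread.is_alive():
            thread.join()
        torch.cuda.empty_cache()
```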
@@ -723,25 +783,25 @@ if __name__ == '__main__':
 ### => summarization model
 ### => larger (12b) for better summarization
 #################################################
-model_aux, tokenizer_aux = FastVisionModel.from_pretrained(
-
-
-)
-FastVisionModel.for_inference(model)
-if "gemma" in cfgs["model"]:
-
-
-
-
-
+# model_aux, tokenizer_aux = FastVisionModel.from_pretrained(
+# model_name=cfgs["model_summarize"],
+# load_in_4bit=True,
+# )
+# FastVisionModel.for_inference(model)
+# if "gemma" in cfgs["model"]:
+# # gemma-specific tokenizer chat template
+# tokenizer_aux = get_chat_template(
+# tokenizer_aux,
+# chat_template = "gemma-3",
+# )
 
-# initialize streamer tokens
-streamer = TextIteratorStreamer(
-
-)
-streamer_aux = TextIteratorStreamer(
-
-)
+# # initialize streamer tokens
+# streamer = TextIteratorStreamer(
+# tokenizer, skip_prompt=True, skip_special_tokens=True
+# )
+# streamer_aux = TextIteratorStreamer(
+# tokenizer_aux, skip_prompt=True, skip_special_tokens=True
+# )
 
 """=============================================
 4. User-input Dropdown Traits
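With the aux (summarization) model and the module-level streamers commented out, only the main vision model setup remains in `__main__`. A rough sketch of what that single-model setup presumably looks like is below; the helper name and the idea that the checkpoint comes from `cfgs["model"]` are assumptions, while the unsloth calls mirror the ones shown (commented) in this hunk.

```python
from unsloth import FastVisionModel
from unsloth.chat_templates import get_chat_template


def load_main_model(model_name: str, load_in_4bit: bool = True):
    # Load the vision-language model and its tokenizer (4-bit to fit a single GPU).
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=model_name,
        load_in_4bit=load_in_4bit,
    )
    FastVisionModel.for_inference(model)  # switch to inference mode

    if "gemma" in model_name:
        # gemma-specific tokenizer chat template
        tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
    return model, tokenizer


# e.g. model, tokenizer = load_main_model(cfgs["model"])
# No module-level streamers remain: vlm_response now builds a fresh
# TextIteratorStreamer per request (see the earlier hunks).
```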
@@ -907,13 +967,6 @@ if __name__ == '__main__':
 ### Health Topic ###
 ##########################
 gr.Markdown("## 2. Please specify the main Health Topic of the poster here:")
-gr.Markdown("""
-#### Notes:
-* Select the main Health Topic of the poster you will upload next.
-* Please make sure the selected Health Topic matches the uploaded poster for best results.
-* If you don’t upload a poster, the model will produce a general response emulating how a person with given traits would say about the selected Health Topic.
-""",
-)
 # ---- dropdown at ~50% page width and centered ----
 with gr.Row():
 with gr.Column(scale=1):

@@ -928,15 +981,22 @@ if __name__ == '__main__':
 ##########################
 gr.Markdown("## 3. Upload Public Health Poster here (if no poster is uploaded, the model emulates General Response to the topic):")
 gr.Markdown("""
-####
-
-
-
-
-
-
+#### ▶️ Use Case 1: Poster-Based Response
++ Upload **only one** poster image — the first file is the one processed.
++ The model has **no memory**, so re-upload the image for each new request.
++ Must choose a **Health Topic** that matches the poster content for best results.
++ No text prompt is needed: upload the poster and click **Submit**.
+#### ▶️ Use Case 2: General Response (No Poster)
++ Simply select a Health Topic and click **Send**.
 """
 )
+gr.Markdown("""
+### 📘 Important Notes
+- ⚠️ **Do not interrupt the generation process.** Stopping midway can cause backend issues. Please allow the response to complete.
+- 🏷️ Before uploading a poster, select its **corresponding health topic**.
+- 🎯 For the best experience, ensure the **topic accurately matches the poster content**.
+- 🧩 If you choose not to upload a poster, the model will produce a **general, trait-conditioned response** for the selected topic.
+""")
 chat = gr.ChatInterface(
 fn=vlm_response,
 multimodal=True, # text + image

@@ -947,18 +1007,20 @@ if __name__ == '__main__':
 political, education, income, family_status,
 # extraversion, agreeableness, conscientiousness, neuroticism, openness,
 ],
-chatbot=gr.Chatbot(height=
+chatbot=gr.Chatbot(height=500), # height=330
 autofocus=False,
 )
 
 """=============================================
 5. Chat Interface Launch
 ============================================="""
-
-
-
-
-
-
-
-
+interface.queue(
+max_size=20,
+default_concurrency_limit=1,
+).launch(
+share=True,
+max_threads=1,
+# show_error=True,
+# prevent_thread_lock=False,
+# debug=True,
+)
app.sh ADDED

@@ -0,0 +1,17 @@
+#!/bin/bash
+#SBATCH -c 16 # 16 CPUs
+#SBATCH --mem=32g # 32 GB RAM
+#SBATCH --gres=gpu:rtxa5000:1 # 1 RTX A5000 GPU
+#SBATCH --time=3-00:00:00 # 3 days
+#SBATCH --account=gamma
+#SBATCH --partition=gamma
+#SBATCH --qos=gamma-huge-long
+#SBATCH --output=/fs/nexus-projects/health_sim_ai/src_hf_deploy/app_logs/app_%j.out
+
+export HOME=/fs/nexus-projects/health_sim_ai
+cd /fs/nexus-projects/health_sim_ai
+source venvs/llm/bin/activate
+cd src_hf_deploy
+python -u app.py
+# python inference_pred_llm.py
+# python inference_rec_llm.py
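Operationally, the job is submitted with `sbatch app.sh` and can be checked with `squeue -u $USER`; the allocated node then runs `python -u app.py` inside the activated virtualenv. Because the Gradio interface launches with `share=True`, the public `*.gradio.live` URL appears in the job's log file under `app_logs/app_<jobid>.out`, the path set by the `#SBATCH --output` directive.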