File size: 7,451 Bytes
a35932e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
import os
import base64
from io import BytesIO
import yaml # type: ignore
from PIL import Image # type: ignore
from tqdm import tqdm # type: ignore
import re
import pandas as pd
import numpy as np
# Seed NumPy's global RNG at import time so the random demographic masking done
# in process_trait_info (training mode, via np.random.rand) is reproducible
# across runs that make the same sequence of calls.
np.random.seed(42)
"""======================================================
Convert PIL images to Base64 encoded strings
:param pil_image: PIL image
:return: Base64-encoded string of the serialized image (the image is not resized)
======================================================"""
def convert_to_base64(pil_image, image_format="PNG"):
    """Encode a PIL image as a Base64 string.

    The image is serialized into an in-memory buffer as-is (no resizing) and
    the resulting bytes are Base64-encoded.

    Args:
        pil_image: PIL image (any object with a ``save(fp, format=...)`` method).
        image_format: Pillow format name used for serialization. Defaults to
            "PNG", the previously hard-coded format.

    Returns:
        str: Base64 text of the serialized image bytes.
    """
    buffered = BytesIO()
    pil_image.save(buffered, format=image_format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
"""======================================================
Load config files
======================================================"""
def load_config(filepath):
    """Load a YAML configuration file.

    Args:
        filepath: Path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load``, or None when the file does
        not exist or is not valid YAML (an error message is printed in both
        cases). Note that an empty file also parses to None.
    """
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
        # mis-decode or reject non-ASCII characters in a UTF-8 config file.
        with open(filepath, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Error: Configuration file '{filepath}' not found.")
        return None
    except yaml.YAMLError as e:
        print(f"Error parsing YAML file '{filepath}': {e}")
        return None
""" ======================================================
Convert individual-poster pair row to LLM conversation format
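Args:
+ sample: a row of individual-poster response (needs "image_name", "instruction", "answer")
+ use_image: if True, attach the poster image for a vision-language model
return: conversation format ({"messages": [user_turn, assistant_turn]})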
======================================================"""
def convert_to_conversation(sample, use_image=True, stimuli_dir="../stimuli"):
    """Convert one individual-poster response row to chat-message format.

    Args:
        sample: Row indexable by "instruction" and "answer", plus "image_name"
            when use_image is True (e.g. a pandas Series or a dict).
        use_image: If True, build a multimodal conversation whose user turn
            carries the poster image; otherwise a text-only conversation.
        stimuli_dir: Directory containing the posters as "<image_name>.png".
            Relative paths resolve against the current working directory; the
            default is the previously hard-coded location.

    Returns:
        dict: {"messages": [user_turn, assistant_turn]}. The answer is always
        converted to str.
    """
    answer = str(sample["answer"])
    if not use_image:
        # Text-only language model: message content is a plain string.
        return {"messages": [
            {"role": "user", "content": sample["instruction"]},
            {"role": "assistant", "content": answer},
        ]}

    # Multimodal vision-language model: content is a list of typed parts.
    poster_path = os.path.abspath(
        os.path.join(stimuli_dir, f"{sample['image_name']}.png")
    )
    image = Image.open(poster_path)
    # Image.open is lazy and keeps the file handle open until the pixel data is
    # read. Loading now lets Pillow close the file it opened (single-frame
    # images) instead of holding one descriptor per converted row, which can
    # exhaust the process's file-descriptor limit on large datasets.
    image.load()
    return {"messages": [
        {"role": "user", "content": [
            {"type": "text", "text": sample["instruction"]},
            {"type": "image", "image": image},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": answer},
        ]},
    ]}
"""
Convert dataframe to SFT-compatible trainer dataset
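Args:
+ dataframe: original dataframe (one individual-poster response per row)
+ use_image: forwarded to convert_to_conversation
return: SFT Trainer-compatible dataset (list of {"messages": ...} dicts, in row order)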
"""
def convert_to_model_trainer_dataset(dataframe, use_image=True):
    """Convert a dataframe into a list of SFT-trainer conversation samples.

    Args:
        dataframe: pandas DataFrame with one individual-poster response per
            row (the columns read by convert_to_conversation).
        use_image: Forwarded to convert_to_conversation.

    Returns:
        list[dict]: One {"messages": [...]} entry per row, in row order.
    """
    # Iterate rows positionally with iterrows() rather than `dataframe.loc[idx]`:
    # with duplicated index labels (e.g. after pd.concat without ignore_index),
    # `.loc` returns a multi-row DataFrame instead of a single row.
    rows = tqdm(dataframe.iterrows(), total=len(dataframe), desc="dataset conversion")
    return [convert_to_conversation(row, use_image) for _, row in rows]
"""
Process and filter persona components (demographics, personality scores, and locus)
based on user selection flags.
Args:
demo_info (str): Raw demographic information text.
persona_score (str): Text block containing personality trait information.
locus (str): Text or variable for locus of control information.
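demo_full (bool): If False, rebuild the demographics block with one line per
known field ("[Not specified]" when missing or masked). If True or None,
demo_info is passed through unchanged.
include_big5 (bool): Whether to include Big-Five personality scores.
include_facet (bool): Whether to include facet-level trait scores.
include_locus (bool): Whether to include locus information.
train_mode (bool):
True, training mode: for ~90% of samples keep only the most sensitive
demographic fields (Gender, Age, Race/Ethnicity, Current Profession,
Annual Household Income); the remaining samples keep all fields.
False, evaluation mode: include all available demographic fields.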
Returns:
tuple: (demo_info, persona_score, locus)
"""
def process_trait_info(
    demo_info, persona_score, locus,
    demo_full=True, include_big5=True,
    include_facet=True, include_locus=True,
    train_mode=True,
):
    """Filter persona components according to the inclusion flags.

    Args:
        demo_info: Raw demographics text ("Field: value" lines), or None/NaN.
        persona_score: Text holding "Big-Five Trait Scores:" and/or
            "Facet Scores:" sections, or None/NaN.
        locus: Locus-of-control information; returned as-is or as None.
        demo_full: If False, demo_info is rebuilt with one line per known field
            (missing, empty or masked fields become "[Not specified]"). If True
            or None, demo_info is returned unchanged.
        include_big5: Keep the Big-Five section (a "[Not specified]" placeholder
            is prepended when it is absent) or strip it.
        include_facet: Keep the facet section (a placeholder is appended when it
            is absent) or strip it.
        include_locus: If False, locus is replaced by None.
        train_mode: Only used when demo_full is False. If True, ~90% of calls
            keep only the sensitive demographic fields; draws once from NumPy's
            global RNG. If False, all available fields are kept.

    Returns:
        tuple: (demo_info, persona_score, locus). persona_score is always a str.
    """
    if demo_full is not None and not demo_full:
        demo_info = _filter_demographics(demo_info, train_mode)
    persona_score = _filter_persona_score(persona_score, include_big5, include_facet)
    if not include_locus:
        locus = None
    return demo_info, persona_score, locus


# Demographic fields, in the order _filter_demographics emits them.
_DEMO_FIELDS = (
    "Gender", "Age", "Current Profession", "Race/Ethnicity",
    "Religious/Cultural Group", "Political Affiliation",
    "Highest Education", "Annual Household Income", "Family Status",
)

# [SPH Request] flagged "Gender", "Age" and "Race/Ethnicity" as the most
# sensitive fields; this set additionally keeps "Current Profession" and
# "Annual Household Income". In training mode ~90% of samples keep only these.
_SENSITIVE_DEMO_FIELDS = frozenset({
    "Gender", "Age", "Race/Ethnicity",
    "Current Profession", "Annual Household Income",
})

# One "Key: value" pair per line. The separator is `[ \t]*` rather than `\s*`
# so a field with an empty value ("Age:") cannot swallow the newline and
# capture the following line as its value.
_DEMO_LINE_RE = re.compile(r"^(.*?):[ \t]*(.*)$", flags=re.MULTILINE)

# A trait section runs from its header up to the next known header or the end
# of the text (DOTALL so the section may span several lines).
_BIG5_SECTION_RE = re.compile(
    r"Big-Five Trait Scores:.*?(?=(Facet Scores:|Demographics:|$))", flags=re.DOTALL
)
_FACET_SECTION_RE = re.compile(
    r"Facet Scores:.*?(?=(Demographics:|$))", flags=re.DOTALL
)


def _filter_demographics(demo_info, train_mode):
    """Rebuild demo_info as a "Demographics:" header plus one line per _DEMO_FIELDS entry."""
    # The RNG is drawn once per call and only in training mode, so evaluation
    # calls never advance NumPy's global random stream.
    keep_sensitive_only = bool(train_mode) and np.random.rand() <= 0.9

    if demo_info is None or pd.isna(demo_info):
        demo_info = ""
    demo_info = demo_info.replace("Demographics:", "").replace("Demographics", "")
    # Unknown keys are dropped; a repeated key keeps its last value.
    parsed = {
        key.strip(): value.strip()
        for key, value in _DEMO_LINE_RE.findall(demo_info)
        if key.strip() in _DEMO_FIELDS
    }

    lines = ["Demographics:"]
    for field in _DEMO_FIELDS:
        value = parsed.get(field)
        allowed = not keep_sensitive_only or field in _SENSITIVE_DEMO_FIELDS
        # Missing, empty and masked fields all get the same placeholder.
        lines.append(f"{field}: {value if value and allowed else '[Not specified]'}")
    return "\n".join(lines) + "\n"


def _filter_persona_score(persona_score, include_big5, include_facet):
    """Strip or placeholder-fill the Big-Five and facet sections; always returns a str."""
    missing = persona_score is None or pd.isna(persona_score)
    if not include_big5 and not missing:
        persona_score = _BIG5_SECTION_RE.sub("", persona_score)
    else:
        # NOTE(review): a missing score also gets the Big-Five placeholder when
        # include_big5 is False; kept as-is — confirm this is intended.
        if missing:
            persona_score = ""
        if "Big-Five Trait Scores" not in persona_score:
            persona_score = "\nBig-Five Trait Scores: [Not specified]\n" + persona_score

    # persona_score is a str from here on.
    if not include_facet:
        persona_score = _FACET_SECTION_RE.sub("", persona_score)
    elif "Facet Scores" not in persona_score:
        persona_score = persona_score + "\nFacet Scores: [Not specified]\n"
    return persona_score