File size: 7,451 Bytes
a35932e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import base64
from io import BytesIO
import yaml # type: ignore
from PIL import Image # type: ignore
from tqdm import tqdm # type: ignore
import re
import pandas as pd
import numpy as np
# Seed NumPy's *global* RNG at import time so that the random demographic
# subsampling in process_trait_info (np.random.rand) is reproducible run to run.
# NOTE: this seeds global state, so it also affects any other np.random user
# in the same process.
np.random.seed(42)


"""======================================================
    Convert PIL images to Base64 encoded strings

    :param pil_image: PIL image
    :return: Re-sized Base64 string
======================================================"""
def convert_to_base64(pil_image):
    buffered = BytesIO()
    pil_image.save(buffered, format="PNG")  # You can change the format if needed
    img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_b64 

"""======================================================
        Load config files
======================================================"""
def load_config(filepath):
    try:
        with open(filepath, 'r') as f:
            config = yaml.safe_load(f)
        return config
    except FileNotFoundError:
        print(f"Error: Configuration file '{filepath}' not found.")
        return None
    except yaml.YAMLError as e:
        print(f"Error parsing YAML file '{filepath}': {e}")
        return None
    


""" ======================================================
    Convert individual-poster pair row to LLM conversation format
    
    Args:
        + sample: a row of individual-poster response
    return: conversation format
======================================================"""
def convert_to_conversation(sample, use_image=True):
    poster_id = sample["image_name"]
    poster_path = os.path.abspath(f"../stimuli/{poster_id}.png")

    if use_image:
        # Multimodal Vision-Language Model
        image = Image.open(poster_path)
        conversation = [
            { "role": "user",
            "content" : [
                {"type" : "text",  "text"  : sample["instruction"]},
                {"type" : "image", "image" : image} ]
            },
            { "role" : "assistant",
            "content" : [
                {"type" : "text",  "text"  : str(sample["answer"])} ]
            },
        ]
    else:
        # Text-only Language Model
        conversation = [
            { "role": "user", "content" : sample["instruction"]},
            { "role" : "assistant", "content" : str(sample["answer"])},
        ]
    return { "messages" : conversation }

"""
    Convert dataframe to SFT-compatible trainer dataset

    Args:
        + dataframe: original dataframe
    return: SFT Trainer-compatible dataset
"""
def convert_to_model_trainer_dataset(dataframe, use_image=True):
    trainer_dataset = []
    for idx in tqdm(dataframe.index, desc="dataset conversion"):
        sample = dataframe.loc[idx]
        trainer_dataset.append(convert_to_conversation(sample, use_image))

    return trainer_dataset


"""
    Process and filter persona components (demographics, personality scores, and locus)
    based on user selection flags.

    Args:
        demo_info (str): Raw demographic information text.
        persona_score (str): Text block containing personality trait information.
        locus (str): Text or variable for locus of control information.
        demo_full (bool): If False, only include selected demographic fields.
        include_big5 (bool): Whether to include Big-Five personality scores.
        include_facet (bool): Whether to include facet-level trait scores.
        include_locus (bool): Whether to include locus information.
        train_mode (bool): 
            True, use training mode processing: use sensitive traits only.
            False, use evaluation mode processing: include all specified traits.

    Returns:
        tuple: (demo_info, persona_score, locus)
"""
def process_trait_info(
        demo_info, persona_score, locus,
        demo_full=True, include_big5=True,
        include_facet=True, include_locus=True,
        train_mode=True,
    ):

    # ==============================
    # 1. Demographic filtering
    # ==============================
    if not demo_full and (demo_full is not None):
        fields = [
            "Gender", "Age", "Current Profession", "Race/Ethnicity",
            "Religious/Cultural Group", "Political Affiliation",
            "Highest Education", "Annual Household Income", "Family Status"
        ]

        """ ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        [SPH Request] Most sensitive to ["Gender", "Age", and "Race/Ethnicity".]
        ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ """
        if train_mode:
            # 90% of the time, only include the 3 most sensitive fields
            fields_2 = ["Gender", "Age", "Race/Ethnicity", "Current Profession", "Annual Household Income",]
            # use_top3_sensitive = True if (np.random.rand() <= 0.9) else False
            use_top5_sensitive = True if (np.random.rand() <= 0.9) else False
        else:
            # include all user-specified fields during evaluation
            # use_top3_sensitive = False
            use_top5_sensitive = False


        pattern = r"^(.*?):\s*(.*)$"
        if pd.isna(demo_info) or demo_info is None:
            demo_info = ""
            
        demo_info = demo_info.replace("Demographics:", "").replace("Demographics", "")
        matches = re.findall(pattern, demo_info, flags=re.MULTILINE)

        demo_dict = {k.strip(): v.strip() for k, v in matches if (k.strip().replace(':', '') in fields)}

        demo_info = "Demographics:\n"
        for k in fields:
            # we are not using the top 3 sensitive filtering OR trait k must be in top-3 sensitive
            trait_k_included = (not use_top5_sensitive) or (k in fields_2)
            # demo trait must be available + selected
            if k in demo_dict and trait_k_included:
                demo_info += f"{k}: {demo_dict[k]}\n"
            else:
                demo_info += f"{k}: [Not specified]\n"

    # ==============================
    # 2. Personality trait filtering
    #   => if not available, set
    # ==============================
    if not include_big5 and (persona_score is not None) and (not pd.isna(persona_score)):
        # exclude any big5 traits in the data
        persona_score = re.sub(
            r"Big-Five Trait Scores:.*?(?=(Facet Scores:|Demographics:|$))",
            "",
            persona_score,
            flags=re.DOTALL
        )
    else:
        # include big5 trait
        if (persona_score is None) or (pd.isna(persona_score)):
            persona_score = ""
        if "Big-Five Trait Scores" not in persona_score:
            persona_score = "\nBig-Five Trait Scores: [Not specified]\n" + persona_score

    if not include_facet and (persona_score is not None) and (not pd.isna(persona_score)):
        # exclude any facet traits in the data
        persona_score = re.sub(
            r"Facet Scores:.*?(?=(Demographics:|$))",
            "",
            persona_score,
            flags=re.DOTALL
        )
    else:
        # include facet trait
        if (persona_score is None) or (pd.isna(persona_score)):
            persona_score = ""
        if "Facet Scores" not in persona_score:
            persona_score = persona_score + "\nFacet Scores: [Not specified]\n"

    # ==============================
    # 3. Locus inclusion
    # ==============================
    if not include_locus:
        locus = None

    return demo_info, persona_score, locus