rahul7star committed
Commit dba7fbf · verified · 1 parent: 17cde10

Update app.py

Files changed (1): app.py +52 -55
app.py CHANGED
@@ -3,26 +3,24 @@
 Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
 - Gemma LLM default
 - Auto LoRA target modules
-- CSV/Parquet datasets
-- Dropdowns for short/long prompt columns, batch size, num_workers
-- Live logs (tokenization, forward/backward, step loss)
-- Live progress bar
+- CSV/Parquet support
+- Live logs and progress
 """
 
-import os, torch, gradio as gr, pandas as pd, numpy as np
+import os, torch, gradio as gr, pandas as pd
 from pathlib import Path
-from tqdm.auto import tqdm
-from huggingface_hub import create_repo, upload_folder, hf_hub_download
 from torch.utils.data import Dataset, DataLoader
+from tqdm.auto import tqdm
 from peft import LoraConfig, get_peft_model
 from accelerate import Accelerator
 import torch.nn as nn
+from huggingface_hub import create_repo, upload_folder, hf_hub_download
 
-# Optional LLM support
+# Transformers support
 try:
     from transformers import AutoTokenizer, AutoModelForCausalLM
     TRANSFORMERS_AVAILABLE = True
-except Exception:
+except:
     TRANSFORMERS_AVAILABLE = False
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -55,8 +53,9 @@ class MediaTextDataset(Dataset):
 
         self.text_columns = text_columns or ["short_prompt", "long_prompt"]
 
+        # Debug prints
         print(f"[DEBUG] Loaded dataset: {file_path}, columns: {list(self.df.columns)}")
-        print(f"[DEBUG] Sample rows:\n{self.df.head(3)}")
+        print(f"[DEBUG] Sample row:\n{self.df.head(3)}")
 
     def __len__(self):
         return len(self.df)
@@ -66,31 +65,30 @@ class MediaTextDataset(Dataset):
         text_data = {col: rec[col] if col in rec else "" for col in self.text_columns}
         return {"text": text_data}
 
-# ---------------- Dynamic pipeline loader ----------------
+# ---------------- Model Loader ----------------
 def load_pipeline_auto(base_model, dtype=torch.float16):
-    low = base_model.lower()
-    if "gemma" in low:
+    if "gemma" in base_model.lower():
         if not TRANSFORMERS_AVAILABLE:
-            raise RuntimeError("Transformers not installed for LLM support.")
+            raise RuntimeError("Transformers not installed")
         print(f"[INFO] Using Gemma LLM for {base_model}")
         tokenizer = AutoTokenizer.from_pretrained(base_model)
         model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
         return {"model": model, "tokenizer": tokenizer}
     else:
-        raise NotImplementedError("Only Gemma LLM is implemented for LoRA training in this version.")
+        raise NotImplementedError("Only Gemma LLM supported currently")
 
-def find_target_modules(model, model_name=None):
+def find_target_modules(model):
     candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
     names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
     targets = [n.split(".")[-1] for n in names if n.split(".")[-1] in candidates]
     if not targets:
         targets = [n.split(".")[-1] for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-        print(f"[WARNING] No standard attention modules found in {model_name}, using all Linear layers for LoRA")
+        print("[WARNING] No standard attention modules found, using all Linear layers")
     else:
         print(f"[INFO] LoRA target modules detected: {targets}")
     return targets
 
-# ---------------- Training (generator) ----------------
+# ---------------- Training generator ----------------
 def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
                       epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1, num_workers=0,
                       max_train_records=None):
@@ -98,17 +96,13 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
     pipe = load_pipeline_auto(base_model)
     model_obj = pipe["model"]
     tokenizer = pipe["tokenizer"]
-    target_modules = find_target_modules(model_obj, base_model)
+    target_modules = find_target_modules(model_obj)
     lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
     lora_module = get_peft_model(model_obj, lcfg)
 
     dataset = MediaTextDataset(dataset_src, csv_name, text_columns=text_cols, max_records=max_train_records)
     loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-
-    # Prepare with accelerator
-    lora_module, opt, loader = accelerator.prepare(
-        lora_module, torch.optim.AdamW(lora_module.parameters(), lr=lr), loader
-    )
+    lora_module, opt, loader = accelerator.prepare(lora_module, torch.optim.AdamW(lora_module.parameters(), lr=lr), loader)
 
     total_steps = epochs * len(loader)
     step_counter = 0
@@ -117,7 +111,7 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
     yield "[DEBUG] Starting training loop...\n", 0.0
 
     for ep in range(epochs):
-        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter/total_steps
+        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter / total_steps
         for i, batch in enumerate(loader):
             ex = batch[0]
             texts = ex["text"]
@@ -125,7 +119,7 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
             # Tokenization
            tokens = tokenizer([texts.get("short_prompt",""), texts.get("long_prompt","")],
                                padding=True, truncation=True, return_tensors="pt").to(DEVICE)
-            logs.append(f"[DEBUG] Step {step_counter}, tokens input_ids shape: {tokens['input_ids'].shape}")
+            logs.append(f"[DEBUG] Step {step_counter}, input_ids shape: {tokens['input_ids'].shape}")
 
             # Forward pass
             outputs = lora_module(**tokens)
@@ -138,17 +132,17 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
             opt.zero_grad()
 
             step_counter += 1
-            # Yield last 10 logs + progress
-            yield "\n".join(logs[-10:]), step_counter/total_steps
+            yield "\n".join(logs[-10:]), step_counter / total_steps
 
     Path(output_dir).mkdir(exist_ok=True)
     lora_module.save_pretrained(output_dir)
     yield f"[INFO] LoRA saved to {output_dir}\n", 1.0
 
-# ---------------- Upload ----------------
+# ---------------- HF Upload ----------------
 def upload_adapter(local, repo_id):
-    token=os.environ.get("HF_TOKEN")
-    if not token: raise RuntimeError("HF_TOKEN missing")
+    token = os.environ.get("HF_TOKEN")
+    if not token:
+        raise RuntimeError("HF_TOKEN missing")
     create_repo(repo_id, exist_ok=True)
     upload_folder(local, repo_id=repo_id, repo_type="model", token=token)
     return f"https://huggingface.co/{repo_id}"
@@ -159,43 +153,46 @@ def run_ui():
         gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Gemma LLM)")
 
         with gr.Row():
-            base_model=gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            dataset=gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
-            csvname=gr.Textbox(label="CSV/Parquet file", value="train.csv")
-            short_col=gr.Textbox(label="Short prompt column", value="short_prompt")
-            long_col=gr.Textbox(label="Long prompt column", value="long_prompt")
-            out=gr.Textbox(label="Output dir", value="./adapter_out")
-            repo=gr.Textbox(label="Upload HF repo (optional)", value="rahul7star/gemma-3-270m-ccebc0")
+            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
+            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
+            csvname = gr.Textbox(label="CSV/Parquet file", value="train.csv")
+            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
+            long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
+            out = gr.Textbox(label="Output dir", value="./adapter_out")
+            repo = gr.Textbox(label="Upload HF repo (optional)", value="")
+
         with gr.Row():
             batch_size = gr.Number(value=1, label="Batch size")
             num_workers = gr.Number(value=0, label="DataLoader num_workers")
-            r=gr.Slider(1,64,value=8,label="LoRA rank")
-            a=gr.Slider(1,64,value=16,label="LoRA alpha")
-            ep=gr.Number(value=1,label="Epochs")
-            lr=gr.Number(value=1e-4,label="Learning rate")
+            r = gr.Slider(1, 64, value=8, label="LoRA rank")
+            a = gr.Slider(1, 64, value=16, label="LoRA alpha")
+            ep = gr.Number(value=1, label="Epochs")
+            lr = gr.Number(value=1e-4, label="Learning rate")
             max_records = gr.Number(value=1000, label="Max training records")
-            btn=gr.Button("🚀 Start Training")
-            logs=gr.Textbox(label="Logs", lines=20)
-            progress = gr.Progress()
+            btn = gr.Button("🚀 Start Training")
+
+        logs_box = gr.Textbox(label="Logs", lines=20)
+        progress_bar = gr.Progress()
 
-        def launch(bm,ds,csv,sc,lc,out_dir,batch,num_w,r_,a_,ep_,lr_,max_rec,repo_):
-            # Stream logs from generator
+        def launch(bm, ds, csv, sc, lc, out_dir, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
             for log_text, prog in train_lora_stream(
                 bm, ds, csv, [sc, lc], out_dir,
                 int(ep_), float(lr_), int(r_), int(a_),
                 int(batch), int(num_w), max_train_records=int(max_rec)
             ):
-                yield log_text, prog
-
-            # Upload if repo provided
+                progress_bar.progress = prog
+                yield log_text
             if repo_:
                 link = upload_adapter(out_dir, repo_)
-                yield f"[INFO] Uploaded to {link}", 1.0
+                yield f"[INFO] Uploaded to {link}"
+
+        btn.click(
+            launch,
+            [base_model, dataset, csvname, short_col, long_col, out, batch_size, num_workers, r, a, ep, lr, max_records, repo],
+            logs_box
+        )
 
-        btn.click(launch,
-                  [base_model,dataset,csvname,short_col,long_col,out,batch_size,num_workers,r,a,ep,lr,max_records,repo],
-                  [logs, progress])
     return demo
 
-if __name__=="__main__":
+if __name__ == "__main__":
     run_ui().launch(server_name="0.0.0.0", server_port=7860, share=True)
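
The most visible change in this commit is the event wiring: the click handler is now a generator whose yields stream into a single Logs textbox, rather than returning to both a textbox and a progress output. A minimal, self-contained sketch of that streaming pattern follows; the component names here are illustrative, and it assumes a recent Gradio release that accepts generator functions as event handlers.

import time
import gradio as gr

def stream_logs():
    # Generator handler: each yield replaces the textbox content,
    # so the UI re-renders the log tail while the loop runs.
    lines = []
    for step in range(5):
        lines.append(f"[DEBUG] step {step}")
        time.sleep(0.2)  # stand-in for a training step
        yield "\n".join(lines[-10:])

with gr.Blocks() as demo:
    btn = gr.Button("Run")
    logs_box = gr.Textbox(label="Logs", lines=10)
    # A single output component receives each yielded string.
    btn.click(stream_logs, inputs=None, outputs=logs_box)

if __name__ == "__main__":
    demo.launch()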