rahul7star committed on
Commit 4cd3373 · verified · 1 Parent(s): 5183059

Update app_gpu.py

Files changed (1):
  1. app_gpu.py +144 -76
app_gpu.py CHANGED
@@ -1,31 +1,41 @@
-# universal_lora_trainer_gradio_tabs.py
-import spaces
 import os
 import torch
 import gradio as gr
 import pandas as pd
 from pathlib import Path
 from torch.utils.data import Dataset, DataLoader
 from peft import LoraConfig, get_peft_model
 from accelerate import Accelerator
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from huggingface_hub import create_repo, upload_folder
-from tempfile import TemporaryDirectory

-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 # ---------------- Dataset ----------------
 class MediaTextDataset(Dataset):
     def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
-        self.is_hub = "/" in source and not Path(source).exists()
         token = os.environ.get("HF_TOKEN")
         if self.is_hub:
-            from huggingface_hub import hf_hub_download
-            file_path = hf_hub_download(repo_id=source, filename=csv_name, repo_type="dataset", token=token)
         else:
             file_path = Path(source) / csv_name

-        # fallback to parquet
         if not Path(file_path).exists():
             alt = Path(str(file_path).replace(".csv", ".parquet"))
             if alt.exists():
@@ -36,6 +46,7 @@ class MediaTextDataset(Dataset):
         self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
         if max_records:
             self.df = self.df.head(max_records)
         self.text_columns = text_columns or ["short_prompt", "long_prompt"]

     def __len__(self):
@@ -43,18 +54,21 @@ class MediaTextDataset(Dataset):

     def __getitem__(self, i):
         rec = self.df.iloc[i]
-        return {"text": {col: rec[col] if col in rec else "" for col in self.text_columns}}
-
-# ---------------- Model helpers ----------------
-def load_pipeline(base_model, lora_repo=None, dtype=torch.float16):
-    tokenizer = AutoTokenizer.from_pretrained(base_model)
-    model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
-    if lora_repo:
-        from peft import PeftModel
-        model = PeftModel.from_pretrained(model, lora_repo)
-    model.to(DEVICE)
-    model.eval()
-    return model, tokenizer

 def find_target_modules(model):
     candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
@@ -65,35 +79,57 @@ def find_target_modules(model):
     return targets

 def unwrap_batch(batch, short_col, long_col):
     if isinstance(batch, dict):
         s = batch.get(short_col, batch.get("short", ""))
         l = batch.get(long_col, batch.get("long", ""))
         return {"text": {short_col: str(s), long_col: str(l)}}
-    if isinstance(batch, (list, tuple)):
-        ex = batch[0]
-        return unwrap_batch(ex, short_col, long_col)
     return {"text": {short_col: str(batch), long_col: ""}}

-# ---------------- Training ----------------
-import spaces

 @spaces.GPU(duration=110)
 def train_lora_stream(base_model, dataset_src, csv_name, text_cols,
                       epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1,
                       num_workers=0, max_train_records=None, hf_repo_id=None):

-    if not hf_repo_id:
-        raise ValueError("❌ HF repo ID is required for upload.")
-    HF_TOKEN = os.environ.get("HF_TOKEN")
-    if not HF_TOKEN:
-        raise ValueError("❌ HF_TOKEN missing.")
-
-    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
     accelerator = Accelerator()
-    tokenizer = AutoTokenizer.from_pretrained(base_model)
-    model_obj = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
-    model_obj.train()

     target_modules = find_target_modules(model_obj)
     lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
     lora_module = get_peft_model(model_obj, lora_config)
@@ -107,48 +143,81 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols,
     step_counter = 0
     logs = []

-    yield "[INFO] Starting LoRA training...\n", 0.0

     for ep in range(epochs):
         for batch in loader:
             if step_counter >= max_steps:
                 break
             ex = unwrap_batch(batch, text_cols[0], text_cols[1])
-            enc = tokenizer(ex["text"][text_cols[0]], text_pair=ex["text"][text_cols[1]],
-                            return_tensors="pt", padding="max_length", truncation=True, max_length=512)
-            enc = {k: v.to(accelerator.device) for k, v in enc.items()}
             enc["labels"] = enc["input_ids"].clone()
             outputs = lora_module(**enc)
             loss = getattr(outputs, "loss", None)
             if loss is None:
                 logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
-                loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)),
-                                                         enc["labels"].view(-1),
-                                                         ignore_index=tokenizer.pad_token_id)
             optimizer.zero_grad()
             accelerator.backward(loss)
             optimizer.step()
             step_counter += 1
-            logs.append(f"Step {step_counter}, Loss {loss.item():.4f}")
             yield "\n".join(logs[-10:]), step_counter / max_steps
         if step_counter >= max_steps:
             break

-    # Upload to HF
     create_repo(hf_repo_id, repo_type="model", exist_ok=True, token=HF_TOKEN)
     with TemporaryDirectory() as tmp_dir:
         lora_module.save_pretrained(tmp_dir)
         upload_folder(folder_path=tmp_dir, repo_id=hf_repo_id, repo_type="model", token=HF_TOKEN)
     link = f"https://huggingface.co/{hf_repo_id}"
     logs.append(f"[INFO] ✅ Uploaded successfully: {link}")
     yield "\n".join(logs), link

-# ---------------- Inference ----------------
-def generate_long_prompt(base_model, lora_repo, short_prompt, max_length=200):
-    model, tokenizer = load_pipeline(base_model, lora_repo=lora_repo)
-    input_ids = tokenizer(short_prompt, return_tensors="pt").input_ids.to(DEVICE)
     with torch.no_grad():
-        outputs = model.generate(input_ids, max_length=max_length, do_sample=True, top_p=0.95, top_k=50)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)

 # ---------------- Gradio UI ----------------
@@ -156,25 +225,27 @@ def run_ui():
     with gr.Blocks() as demo:
         gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer & Inference")

-        with gr.Tab("Training"):
-            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
-            csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
-            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
-            long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
-            repo = gr.Textbox(label="HF repo ID for LoRA upload (required)", value="rahul7star/gemma-3-270m-ccebc0")
-
-            batch_size = gr.Number(value=1, label="Batch size")
-            num_workers = gr.Number(value=0, label="DataLoader num_workers")
-            r = gr.Number(value=8, label="LoRA rank")
-            a = gr.Number(value=16, label="LoRA alpha")
-            ep = gr.Number(value=1, label="Epochs")
-            lr = gr.Number(value=1e-4, label="Learning rate")
-            max_records = gr.Number(value=1000, label="Max training records")

             logs = gr.Textbox(label="Logs (streaming)", lines=25)

-            def launch_training(bm, ds, csv, sc, lc, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
                 gen = train_lora_stream(
                     bm, ds, csv, [sc, lc],
                     epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),
@@ -185,23 +256,20 @@ def run_ui():
                     yield item

             btn = gr.Button("🚀 Start Training")
-            btn.click(fn=launch_training,
                       inputs=[base_model, dataset, csvname, short_col, long_col,
                               batch_size, num_workers, r, a, ep, lr, max_records, repo],
                       outputs=[logs],
                       queue=True)

-        with gr.Tab("Inference"):
             inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            inf_lora_repo = gr.Textbox(label="LoRA HF repo ID", value="rahul7star/gemma-3-270m-ccebc0")
-            short_prompt = gr.Textbox(label="Short Prompt")
-            long_prompt_out = gr.Textbox(label="Generated Long Prompt", lines=5)
-
-            def run_inference(bm, lora_repo, sp):
-                return generate_long_prompt(bm, lora_repo, sp)

             inf_btn = gr.Button("📝 Generate Long Prompt")
-            inf_btn.click(fn=run_inference,
                           inputs=[inf_base_model, inf_lora_repo, short_prompt],
                           outputs=[long_prompt_out])
 
 
+# universal_lora_trainer_gradio.py
 import os
 import torch
 import gradio as gr
 import pandas as pd
+import numpy as np
 from pathlib import Path
 from torch.utils.data import Dataset, DataLoader
 from peft import LoraConfig, get_peft_model
 from accelerate import Accelerator
+from huggingface_hub import create_repo, upload_folder, hf_hub_download
+
+# transformers optional
+try:
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+    TRANSFORMERS_AVAILABLE = True
+except Exception:
+    TRANSFORMERS_AVAILABLE = False
+
+# ---------------- Helpers ----------------
+def is_hub_repo_like(s):
+    return "/" in s and not Path(s).exists()

+def download_from_hf(repo_id, filename, token=None):
+    token = token or os.environ.get("HF_TOKEN")
+    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
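A quick sanity check of the new source-resolution heuristic (a minimal sketch; the repo id below is just the UI default):

    # "user/repo" with no matching local path is treated as a Hub dataset repo
    is_hub_repo_like("rahul7star/prompt-enhancer-dataset-01")  # True unless a local dir of that name exists
    is_hub_repo_like("./data")                                 # False when ./data exists locally, True otherwise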
 
 # ---------------- Dataset ----------------
 class MediaTextDataset(Dataset):
     def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
+        self.is_hub = is_hub_repo_like(source)
         token = os.environ.get("HF_TOKEN")
         if self.is_hub:
+            file_path = download_from_hf(source, csv_name, token)
         else:
             file_path = Path(source) / csv_name

+        # fallback to parquet if CSV missing
         if not Path(file_path).exists():
             alt = Path(str(file_path).replace(".csv", ".parquet"))
             if alt.exists():

         self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
         if max_records:
             self.df = self.df.head(max_records)
+
         self.text_columns = text_columns or ["short_prompt", "long_prompt"]

     def __len__(self):

     def __getitem__(self, i):
         rec = self.df.iloc[i]
+        out = {"text": {}}
+        for col in self.text_columns:
+            out["text"][col] = rec[col] if col in rec else ""
+        return out
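Records come back in the same shape as the old dict-comprehension version; a minimal usage sketch with the defaults the UI ships:

    ds = MediaTextDataset("rahul7star/prompt-enhancer-dataset-01",
                          csv_name="train-00000-of-00001.csv", max_records=4)
    print(len(ds), ds[0]["text"]["short_prompt"])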
+
+# ---------------- Model loader ----------------
+def load_pipeline_auto(base_model, dtype=torch.float16):
+    if "gemma" in base_model.lower():
+        if not TRANSFORMERS_AVAILABLE:
+            raise RuntimeError("Transformers not installed for LLM support.")
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
+        model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
+        return {"model": model, "tokenizer": tokenizer}
+    else:
+        raise NotImplementedError("Only Gemma LLM supported in this script.")
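Note the return type changed from the old load_pipeline's (model, tokenizer) tuple to a dict, so callers now unpack it explicitly:

    pipe = load_pipeline_auto("google/gemma-3-4b-it", dtype=torch.float32)
    model, tokenizer = pipe["model"], pipe["tokenizer"]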

 def find_target_modules(model):
     candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

     return targets

 def unwrap_batch(batch, short_col, long_col):
+    if isinstance(batch, (list, tuple)):
+        ex = batch[0]
+        if "text" in ex:
+            return ex
+        if "short" in ex and "long" in ex:
+            return {"text": {short_col: ex.get("short",""), long_col: ex.get("long","")}}
+        return {"text": ex}
+
     if isinstance(batch, dict):
+        first_elem = {}
+        is_batched = any(isinstance(v, (list, tuple, np.ndarray, torch.Tensor)) for v in batch.values())
+        if is_batched:
+            for k, v in batch.items():
+                try: first = v[0]
+                except Exception: first = v
+                first_elem[k] = first
+            if "text" in first_elem:
+                t = first_elem["text"]
+                if isinstance(t, (list, tuple)) and len(t) > 0:
+                    return {"text": t[0] if isinstance(t[0], dict) else {short_col: t[0], long_col: ""}}
+                if isinstance(t, dict): return {"text": t}
+                return {"text": {short_col: str(t), long_col: ""}}
+            if ("short" in first_elem and "long" in first_elem) or (short_col in first_elem and long_col in first_elem):
+                s = first_elem.get(short_col, first_elem.get("short", ""))
+                l = first_elem.get(long_col, first_elem.get("long", ""))
+                return {"text": {short_col: str(s), long_col: str(l)}}
+            return {"text": {short_col: str(first_elem)}}
+        if "text" in batch and isinstance(batch["text"], dict):
+            return {"text": batch["text"]}
         s = batch.get(short_col, batch.get("short", ""))
         l = batch.get(long_col, batch.get("long", ""))
         return {"text": {short_col: str(s), long_col: str(l)}}
     return {"text": {short_col: str(batch), long_col: ""}}
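Two representative cases for the widened unwrap_batch, traced against the code above (a sketch; the prompt strings are arbitrary):

    # plain dict of strings -> falls through to the .get() path
    unwrap_batch({"short_prompt": "a cat", "long_prompt": "a fluffy cat"},
                 "short_prompt", "long_prompt")
    # -> {"text": {"short_prompt": "a cat", "long_prompt": "a fluffy cat"}}

    # list batch whose first element already carries "text" -> returned as-is
    unwrap_batch([{"text": {"short_prompt": "a cat", "long_prompt": ""}}],
                 "short_prompt", "long_prompt")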

+# ---------------- LoRA Training ----------------
+from tempfile import TemporaryDirectory
+from accelerate import Accelerator

 @spaces.GPU(duration=110)
 def train_lora_stream(base_model, dataset_src, csv_name, text_cols,
                       epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1,
                       num_workers=0, max_train_records=None, hf_repo_id=None):

+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = torch.float16 if device=="cuda" else torch.float32
     accelerator = Accelerator()
+    pipe = load_pipeline_auto(base_model, dtype=dtype)
+    model_obj = pipe["model"]
+    tokenizer = pipe["tokenizer"]

+    model_obj.train()
     target_modules = find_target_modules(model_obj)
     lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
     lora_module = get_peft_model(model_obj, lora_config)

     step_counter = 0
     logs = []

+    yield f"[INFO] Starting LoRA training on {device.upper()} (max {max_steps} steps)...\n", 0.0

     for ep in range(epochs):
+        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter / max_steps
         for batch in loader:
             if step_counter >= max_steps:
                 break
+
             ex = unwrap_batch(batch, text_cols[0], text_cols[1])
+            texts = ex.get("text", {})
+            short_text = str(texts.get(text_cols[0], "") or "")
+            long_text = str(texts.get(text_cols[1], "") or "")
+
+            enc = tokenizer(short_text, text_pair=long_text, return_tensors="pt",
+                            padding="max_length", truncation=True, max_length=512)
+            enc = {k: v.to(accelerator.device) for k,v in enc.items()}
             enc["labels"] = enc["input_ids"].clone()
+
             outputs = lora_module(**enc)
             loss = getattr(outputs, "loss", None)
             if loss is None:
                 logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
+                loss = torch.nn.functional.cross_entropy(
+                    logits.view(-1, logits.size(-1)),
+                    enc["labels"].view(-1),
+                    ignore_index=tokenizer.pad_token_id
+                )
+
             optimizer.zero_grad()
             accelerator.backward(loss)
             optimizer.step()
+
+            logs.append(f"[DEBUG] Step {step_counter}, Loss: {loss.item():.6f}")
             step_counter += 1
             yield "\n".join(logs[-10:]), step_counter / max_steps
+
         if step_counter >= max_steps:
             break

+    # ---------------- Upload to HF ----------------
+    HF_TOKEN = os.environ.get("HF_TOKEN")
+    if not hf_repo_id:
+        raise ValueError("❌ HF repo ID required for upload.")
+    if not HF_TOKEN:
+        raise ValueError("❌ HF_TOKEN missing.")
+
+    hf_repo_id = hf_repo_id.strip()
+    logs.append(f"[INFO] 🚀 Uploading LoRA to Hugging Face repo: {hf_repo_id}")
     create_repo(hf_repo_id, repo_type="model", exist_ok=True, token=HF_TOKEN)
+
     with TemporaryDirectory() as tmp_dir:
         lora_module.save_pretrained(tmp_dir)
         upload_folder(folder_path=tmp_dir, repo_id=hf_repo_id, repo_type="model", token=HF_TOKEN)
+
     link = f"https://huggingface.co/{hf_repo_id}"
     logs.append(f"[INFO] ✅ Uploaded successfully: {link}")
     yield "\n".join(logs), link
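Outside Gradio the generator can be consumed directly; a minimal sketch (the target repo id is hypothetical, and note that the final yield carries the repo link in place of the progress float):

    for log_text, progress in train_lora_stream(
            "google/gemma-3-4b-it", "rahul7star/prompt-enhancer-dataset-01",
            "train-00000-of-00001.csv", ["short_prompt", "long_prompt"],
            max_train_records=8, hf_repo_id="your-user/your-lora"):  # hypothetical repo
        print(progress, "|", log_text.splitlines()[-1] if log_text else "")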

+# ---------------- CPU Inference ----------------
+def generate_long_prompt_cpu(base_model, lora_repo, short_prompt, max_length=200):
+    device = torch.device("cpu")  # force CPU
+    pipe = load_pipeline_auto(base_model)
+    model = pipe["model"].to(device)
+    tokenizer = pipe["tokenizer"]
+
+    # Load LoRA adapter from HF
+    lora_module = get_peft_model(model, LoraConfig(
+        r=8, lora_alpha=16, target_modules=find_target_modules(model)
+    ))
+    lora_module.load_adapter(lora_repo, device=device, adapter_name="default")
+    lora_module.eval()
+
+    input_ids = tokenizer(short_prompt, return_tensors="pt").input_ids.to(device)
     with torch.no_grad():
+        outputs = lora_module.generate(input_ids, max_length=max_length, do_sample=True, top_p=0.95, top_k=50)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
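A one-off call with the UI's default model and adapter repo (a sketch; the prompt string is arbitrary):

    text = generate_long_prompt_cpu("google/gemma-3-4b-it",
                                    "rahul7star/gemma-3-270m-ccebc0",
                                    "a cat on a windowsill")
    print(text)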

 # ---------------- Gradio UI ----------------

     with gr.Blocks() as demo:
         gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer & Inference")

+        with gr.Tab("Train LoRA"):
+            with gr.Row():
+                base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
+                dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
+                csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
+                short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
+                long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
+                repo = gr.Textbox(label="HF repo to upload LoRA", value="rahul7star/gemma-3-270m-ccebc0")
+
+            with gr.Row():
+                batch_size = gr.Number(value=1, label="Batch size")
+                num_workers = gr.Number(value=0, label="DataLoader num_workers")
+                r = gr.Number(value=8, label="LoRA rank")
+                a = gr.Number(value=16, label="LoRA alpha")
+                ep = gr.Number(value=1, label="Epochs")
+                lr = gr.Number(value=1e-4, label="Learning rate")
+                max_records = gr.Number(value=1000, label="Max training records")

             logs = gr.Textbox(label="Logs (streaming)", lines=25)

+            def launch_train(bm, ds, csv, sc, lc, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
                 gen = train_lora_stream(
                     bm, ds, csv, [sc, lc],
                     epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),

                     yield item

             btn = gr.Button("🚀 Start Training")
+            btn.click(fn=launch_train,
                       inputs=[base_model, dataset, csvname, short_col, long_col,
                               batch_size, num_workers, r, a, ep, lr, max_records, repo],
                       outputs=[logs],
                       queue=True)

+        with gr.Tab("Inference (CPU)"):
             inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
+            inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
+            short_prompt = gr.Textbox(label="Short prompt")
+            long_prompt_out = gr.Textbox(label="Generated long prompt", lines=5)

             inf_btn = gr.Button("📝 Generate Long Prompt")
+            inf_btn.click(fn=generate_long_prompt_cpu,
                           inputs=[inf_base_model, inf_lora_repo, short_prompt],
                           outputs=[long_prompt_out])
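The diff view ends inside run_ui(); a minimal sketch of a typical entry point, assuming the file closes by launching the Blocks app (not shown in this commit):

    if __name__ == "__main__":
        run_ui()  # assumption: run_ui() ends with demo.launch(), as is usual for Gradio apps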