# universal_lora_trainer_gradio.py
import spaces
import os
import torch
import gradio as gr
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder, hf_hub_download

# transformers optional
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False

# ---------------- Helpers ----------------
def is_hub_repo_like(s):
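    # Heuristic: a "user/repo"-style string that is not an existing local path is treated as a Hub repo id.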
    return "/" in s and not Path(s).exists()

def download_from_hf(repo_id, filename, token=None):
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)

# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
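    """Prompt-pair dataset backed by a CSV/Parquet file from a local folder or a HF dataset repo.

    Each row is returned as {"text": {column_name: value, ...}} for the configured text columns.
    """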
    def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
        self.is_hub = is_hub_repo_like(source)
        token = os.environ.get("HF_TOKEN")
        if self.is_hub:
            file_path = download_from_hf(source, csv_name, token)
        else:
            file_path = Path(source) / csv_name

        # fallback to parquet if CSV missing
        if not Path(file_path).exists():
            alt = Path(str(file_path).replace(".csv", ".parquet"))
            if alt.exists():
                file_path = alt
            else:
                raise FileNotFoundError(f"Dataset file not found: {file_path}")

        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
        if max_records:
            self.df = self.df.head(max_records)

        self.text_columns = text_columns or ["short_prompt", "long_prompt"]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        rec = self.df.iloc[i]
        out = {"text": {}}
        for col in self.text_columns:
            out["text"][col] = rec[col] if col in rec else ""
        return out

# ---------------- Model loader ----------------
def load_pipeline_auto(base_model, dtype=torch.float16):
    if "gemma" in base_model.lower():
        if not TRANSFORMERS_AVAILABLE:
            raise RuntimeError("Transformers not installed for LLM support.")
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
        return {"model": model, "tokenizer": tokenizer}
    else:
        raise NotImplementedError("Only Gemma LLM supported in this script.")

def find_target_modules(model):
    candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    targets = [n.split(".")[-1] for n in names if any(c in n for c in candidates)]
    if not targets:
        targets = [n.split(".")[-1] for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    return targets

def unwrap_batch(batch, short_col, long_col):
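    """Normalize whatever the DataLoader yields (list of dicts, dict of lists, or plain values)
    into a single example of the form {"text": {short_col: str, long_col: str}}."""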
    if isinstance(batch, (list, tuple)):
        ex = batch[0]
        if "text" in ex:
            return ex
        if "short" in ex and "long" in ex:
            return {"text": {short_col: ex.get("short",""), long_col: ex.get("long","")}}
        return {"text": ex}

    if isinstance(batch, dict):
        first_elem = {}
        is_batched = any(isinstance(v, (list, tuple, np.ndarray, torch.Tensor)) for v in batch.values())
        if is_batched:
            for k, v in batch.items():
                try: first = v[0]
                except Exception: first = v
                first_elem[k] = first
            if "text" in first_elem:
                t = first_elem["text"]
                if isinstance(t, (list, tuple)) and len(t) > 0:
                    return {"text": t[0] if isinstance(t[0], dict) else {short_col: t[0], long_col: ""}}
                if isinstance(t, dict): return {"text": t}
                return {"text": {short_col: str(t), long_col: ""}}
            if ("short" in first_elem and "long" in first_elem) or (short_col in first_elem and long_col in first_elem):
                s = first_elem.get(short_col, first_elem.get("short", ""))
                l = first_elem.get(long_col, first_elem.get("long", ""))
                return {"text": {short_col: str(s), long_col: str(l)}}
            return {"text": {short_col: str(first_elem)}}
        if "text" in batch and isinstance(batch["text"], dict):
            return {"text": batch["text"]}
        s = batch.get(short_col, batch.get("short", ""))
        l = batch.get(long_col, batch.get("long", ""))
        return {"text": {short_col: str(s), long_col: str(l)}}
    return {"text": {short_col: str(batch), long_col: ""}}

# ---------------- LoRA Training ----------------
from tempfile import TemporaryDirectory

@spaces.GPU(duration=110)
def train_lora_stream(base_model, dataset_src, csv_name, text_cols,
                      epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1,
                      num_workers=0, max_train_records=None, hf_repo_id=None):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device=="cuda" else torch.float32
    accelerator = Accelerator()
    pipe = load_pipeline_auto(base_model, dtype=dtype)
    model_obj = pipe["model"]
    tokenizer = pipe["tokenizer"]

    model_obj.train()
    target_modules = find_target_modules(model_obj)
    lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
    lora_module = get_peft_model(model_obj, lora_config)

    dataset = MediaTextDataset(dataset_src, csv_name, text_columns=text_cols, max_records=max_train_records)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    optimizer = torch.optim.AdamW(lora_module.parameters(), lr=lr)
    lora_module, optimizer, loader = accelerator.prepare(lora_module, optimizer, loader)

    max_steps = 150
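    # max_steps caps the optimizer steps so the run finishes within the ZeroGPU slot requested above.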
    step_counter = 0
    logs = []

    yield f"[INFO] Starting LoRA training on {device.upper()} (max {max_steps} steps)...\n", 0.0

    for ep in range(epochs):
        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter / max_steps
        for batch in loader:
            if step_counter >= max_steps:
                break

            ex = unwrap_batch(batch, text_cols[0], text_cols[1])
            texts = ex.get("text", {})
            short_text = str(texts.get(text_cols[0], "") or "")
            long_text = str(texts.get(text_cols[1], "") or "")

            enc = tokenizer(short_text, text_pair=long_text, return_tensors="pt",
                            padding="max_length", truncation=True, max_length=512)
            enc = {k: v.to(accelerator.device) for k,v in enc.items()}
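            # Causal-LM objective: labels are the input ids; the model shifts them internally.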
            enc["labels"] = enc["input_ids"].clone()

            outputs = lora_module(**enc)
            loss = getattr(outputs, "loss", None)
            if loss is None:
                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    enc["labels"].view(-1),
                    ignore_index=tokenizer.pad_token_id
                )

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

            logs.append(f"[DEBUG] Step {step_counter}, Loss: {loss.item():.6f}")
            step_counter += 1
            yield "\n".join(logs[-10:]), step_counter / max_steps

        if step_counter >= max_steps:
            break

    # ---------------- Upload to HF ----------------
    HF_TOKEN = os.environ.get("HF_TOKEN")
    if not hf_repo_id:
        raise ValueError("❌ HF repo ID required for upload.")
    if not HF_TOKEN:
        raise ValueError("❌ HF_TOKEN missing.")

    hf_repo_id = hf_repo_id.strip()
    logs.append(f"[INFO] 🚀 Uploading LoRA to Hugging Face repo: {hf_repo_id}")
    create_repo(hf_repo_id, repo_type="model", exist_ok=True, token=HF_TOKEN)

    with TemporaryDirectory() as tmp_dir:
        lora_module.save_pretrained(tmp_dir)
        upload_folder(folder_path=tmp_dir, repo_id=hf_repo_id, repo_type="model", token=HF_TOKEN)

    link = f"https://huggingface.co/{hf_repo_id}"
    logs.append(f"[INFO] ✅ Uploaded successfully: {link}")
    yield "\n".join(logs), link

# ---------------- CPU Inference ----------------
from peft import PeftModel

def generate_long_prompt_cpu(base_model, lora_repo, short_prompt, max_length=200):
    device = torch.device("cpu")

    # Load base model in float32
    pipe = load_pipeline_auto(base_model, dtype=torch.float32)
    base_model_obj = pipe["model"].to(device)
    tokenizer = pipe["tokenizer"]
    base_model_obj.eval()

    # Load LoRA adapter on CPU
    lora_model = PeftModel.from_pretrained(
        base_model_obj,
        lora_repo,
        torch_dtype=torch.float32,
        device_map={"": device}
    )
    lora_model.eval()

    # OPTIONAL: merge LoRA into base model to avoid PEFT runtime issues
    merged_model = lora_model.merge_and_unload()
    merged_model.eval()

    # Tokenize input
    input_ids = tokenizer(short_prompt, return_tensors="pt").input_ids.to(device)

    # Generate safely
    with torch.no_grad():
        outputs = merged_model.generate(
            input_ids,
            max_length=max_length,
            do_sample=True,
            top_p=0.95,
            top_k=50
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)



# ---------------- Gradio UI ----------------
def run_ui():

    with gr.Blocks(title="Prompt Enhancer Trainer + Inference UI") as demo:
        gr.Markdown("# ✨ Prompt Enhancer Trainer + Inference Playground")
        gr.Markdown("Train, test, and debug your LoRA-enhanced Gemma model easily. Use ZeroGPU for training; CPU is enough for everything else.")
        gr.Markdown("""
🔗 **Quick Links:**
- [📂 View DataSet (rahul7star/prompt-enhancer-dataset-01)](https://huggingface.co/datasets/rahul7star/prompt-enhancer-dataset-01)
- [🤖 View Trained Model (rahul7star/gemma-3-270m-ccebc0)](https://huggingface.co/rahul7star/gemma-3-270m-ccebc0)
""")

        with gr.Tabs():
            # =========================================================
            # 1️⃣ TRAIN LORA TAB
            # =========================================================
            with gr.Tab("Train LoRA"):
                with gr.Row():
                    base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
                    dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
                    csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
                    short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
                    long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
                    repo = gr.Textbox(label="HF repo to upload LoRA", value="rahul7star/gemma-3-270m-ccebc0")

                with gr.Row():
                    batch_size = gr.Number(value=1, label="Batch size")
                    num_workers = gr.Number(value=0, label="DataLoader num_workers")
                    r = gr.Number(value=8, label="LoRA rank")
                    a = gr.Number(value=16, label="LoRA alpha")
                    ep = gr.Number(value=1, label="Epochs")
                    lr = gr.Number(value=1e-4, label="Learning rate")
                    max_records = gr.Number(value=1000, label="Max training records")

                logs = gr.Textbox(label="Logs (streaming)", lines=25)

                def launch_train(bm, ds, csv, sc, lc, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
                    gen = train_lora_stream(
                        bm, ds, csv, [sc, lc],
                        epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),
                        batch_size=int(batch), num_workers=int(num_w),
                        max_train_records=int(max_rec), hf_repo_id=repo_
                    )
                    # Each yield from train_lora_stream is a (log_text, progress-or-link) tuple;
                    # only the log text is streamed to the single Textbox output below.
                    for item in gen:
                        yield item[0] if isinstance(item, tuple) else item

                btn = gr.Button("🚀 Start Training")
                btn.click(
                    fn=launch_train,
                    inputs=[
                        base_model, dataset, csvname, short_col, long_col,
                        batch_size, num_workers, r, a, ep, lr, max_records, repo
                    ],
                    outputs=[logs],
                    queue=True
                )

            # =========================================================
            # 2️⃣ INFERENCE (CPU) TAB
            # =========================================================
            with gr.Tab("Inference (CPU)"):
                inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
                inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
                short_prompt = gr.Textbox(label="Short prompt")
                long_prompt_out = gr.Textbox(label="Generated long prompt", lines=5)

                inf_btn = gr.Button("📝 Generate Long Prompt")
                inf_btn.click(
                    fn=generate_long_prompt_cpu,
                    inputs=[inf_base_model, inf_lora_repo, short_prompt],
                    outputs=[long_prompt_out]
                )

            # =========================================================
            # 3️⃣ SHOW TRAINABLE PARAMS TAB
            # =========================================================
            with gr.Tab("Show Trainable Params"):
                gr.Markdown("### 🧩 View Trainable Parameters in Your LoRA-Enhanced Model")
                with gr.Row():
                    base_model_name = gr.Textbox(label="Base Model", value="google/gemma-2b-it")
                    check_btn = gr.Button("🔍 Show Trainable Layers")
                param_output = gr.Textbox(label="Trainable Parameters Info", lines=30)

                def show_trainable_layers(base_model_name):
                    import torch
                    from peft import get_peft_model, LoraConfig
                    from transformers import AutoModelForCausalLM
                    import io
                    import contextlib

                    buf = io.StringIO()
                    print(f"[INFO] Loading base model: {base_model_name}", file=buf)

                    model = AutoModelForCausalLM.from_pretrained(base_model_name)
                    print("[INFO] Initializing LoRA configuration...", file=buf)
                    config = LoraConfig(
                        r=16,
                        lora_alpha=32,
                        target_modules=[
                            "q_proj", "k_proj", "v_proj",
                            "o_proj", "gate_proj", "up_proj", "down_proj"
                        ]
                    )
                    print("[INFO] Applying LoRA adapters...", file=buf)
                    model = get_peft_model(model, config)

                    print("[INFO] Counting trainable parameters...", file=buf)
                    with contextlib.redirect_stdout(buf):
                        model.print_trainable_parameters()

                    print("\n[INFO] Listing all LoRA-injected layers...", file=buf)
                    lora_layers = [name for name, _ in model.named_modules() if "lora" in name.lower()]
                    if not lora_layers:
                        print("⚠️ No LoRA layers detected. Check target_modules configuration.", file=buf)
                    else:
                        print(f"✅ Found {len(lora_layers)} LoRA-injected submodules:\n", file=buf)
                        for i, layer_name in enumerate(lora_layers[:200]):
                            print(f"  {i+1:03d}. {layer_name}", file=buf)
                        if len(lora_layers) > 200:
                            print(f"...and {len(lora_layers)-200} more layers (truncated)", file=buf)

                    explanation = """
────────────────────────────
### 🔍 What “Adapter (90)” Means

When you initialize LoRA on a large model like **Gemma**, the code scans the model 
to find all modules that can receive LoRA layers — typically:

- **q_proj, k_proj, v_proj** → Query, Key, Value projections  
- **o_proj / out_proj** → Output of attention  
- **gate_proj, up_proj, down_proj** → Feed-forward MLPs  

Each matching layer gets two small trainable matrices **(A, B)** injected.

So if you see:
> Adapter (90)

That means **90 total submodules** were wrapped with LoRA adapters.

You can view them above 👆, or print them programmatically with:

```python
for name, module in model.named_modules():
    if "lora" in name.lower():
        print(name)
```
"""

                    print(explanation, file=buf)
                    return buf.getvalue()
                check_btn.click(show_trainable_layers, inputs=[base_model_name], outputs=[param_output])

   



            # =========================================================
            # 4️⃣ CODE DEBUG TAB
            # =========================================================
            with gr.Tab("Code Debug"):
                gr.Markdown("### 🧩 Code Debug — Understand What's Happening Line by Line")
                gr.Markdown("""
#### 🧰 Step-by-Step Breakdown

**1️⃣ `f"[INFO] Loading base model: {base_model}"`**  
→ Logs which model is being loaded (e.g., `google/gemma-2b-it`)

**2️⃣ `AutoModelForCausalLM.from_pretrained(base_model)`**  
→ Downloads the base Gemma model weights and tokenizer.

**3️⃣ `get_peft_model(model, config)`**  
→ Wraps the model with LoRA and injects adapters into `q_proj`, `k_proj`, `v_proj`, etc.

**4️⃣ Expected console output:**
[INFO] Loading base model: google/gemma-2b-it
[INFO] Preparing dataset...
[INFO] Injecting LoRA adapters...
trainable params: 3.5M || all params: 270M || trainable%: 1.3%

**5️⃣ Training loop (`train_lora_stream`)**  
→ Runs the training steps and streams live progress to the UI.

**6️⃣ `upload_folder(...)`**  
→ Uploads the saved LoRA adapter folder to your chosen HF repo.

---

### 🔍 What “Adapter (90)” Means
When you initialize LoRA on Gemma, it finds **90 target layers** such as:
- `q_proj`, `k_proj`, `v_proj`
- `o_proj`
- `gate_proj`, `up_proj`, `down_proj`

Each layer gets small trainable matrices (A, B).  
So:
> **Adapter (90)** → 90 modules modified by LoRA.

To list them:
```python
for name, module in model.named_modules():
    if "lora" in name.lower():
        print(name)
```
""")
             # =========================================================
             # 5️⃣ CODE EXPLAIN TAB
             # =========================================================
            with gr.Tab("Code Explain"):
                explain_md = gr.Markdown("""
### 🧩 Universal Dynamic LoRA Trainer & Inference — Code Explanation
This project provides an **end-to-end LoRA fine-tuning and inference system** for language models like **Gemma**, built with **Gradio**, **PEFT**, and **Accelerate**.  
It supports both **training new LoRAs** and **generating text** with existing ones — all in a single interface.
---
#### **1️⃣ Imports Overview**
- **Core libs:** `os`, `torch`, `gradio`, `numpy`, `pandas`
- **Training libs:** `peft` (`LoraConfig`, `get_peft_model`), `accelerate` (`Accelerator`)
- **Modeling:** `transformers` (for Gemma base model)
- **Hub integration:** `huggingface_hub` (for uploading adapters)
- **Spaces:** `spaces` — for execution within Hugging Face Spaces
---
#### **2️⃣ Dataset Loading**
- Uses a lightweight **MediaTextDataset** class to load:
  - CSV / Parquet files  
  - or directly from a Hugging Face dataset repo
- Expects two columns:  
  `short_prompt` → Input text  
  `long_prompt` → Target expanded text  
- Supports batching, missing-column checks, and configurable max record limits.
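
For illustration, a minimal usage sketch of the `MediaTextDataset` class defined in this script (the repo and file name are just the defaults shown in the UI):

```python
ds = MediaTextDataset(
    "rahul7star/prompt-enhancer-dataset-01",     # local folder or HF dataset repo id
    csv_name="train-00000-of-00001.csv",
    text_columns=["short_prompt", "long_prompt"],
    max_records=1000,
)
print(len(ds), ds[0]["text"]["short_prompt"])
```
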
---
#### **3️⃣ Model Loading & Preparation**
- Loads **Gemma model and tokenizer** via `AutoModelForCausalLM` and `AutoTokenizer`.
- Automatically detects **target modules** (e.g. `q_proj`, `v_proj`) for LoRA injection.
- Supports `float16` or `bfloat16` precision with `Accelerator` for optimal memory usage.
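
A short sketch using the loader helpers defined in this script (the module list shown is only an example):

```python
pipe = load_pipeline_auto("google/gemma-3-4b-it", dtype=torch.float16)
model, tokenizer = pipe["model"], pipe["tokenizer"]
print(sorted(set(find_target_modules(model))))
# e.g. ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']
```
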
---
#### **4️⃣ LoRA Training Logic**
- Core formula:  
  \[
  W_{eff} = W + \alpha \times (B @ A)
  \]
- Only **A** and **B** matrices are trainable; base model weights remain frozen.
- Configurable parameters:  
  `r` (rank), `alpha` (scaling), `epochs`, `lr`, `batch_size`
- Training logs stream live in the UI, showing step-by-step loss values.
- After training, the adapter is **saved locally** and **uploaded to Hugging Face Hub**.
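
Inside `train_lora_stream`, the LoRA wrapping itself boils down to the following condensed sketch (`r`, `alpha`, and `lr` are the UI defaults):

```python
from peft import LoraConfig, get_peft_model

pipe = load_pipeline_auto("google/gemma-3-4b-it", dtype=torch.float16)
lora_config = LoraConfig(r=8, lora_alpha=16,
                         target_modules=find_target_modules(pipe["model"]),
                         lora_dropout=0.0)
lora_module = get_peft_model(pipe["model"], lora_config)
optimizer = torch.optim.AdamW(lora_module.parameters(), lr=1e-4)
```
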
---
#### **5️⃣ CPU Inference Mode**
- Runs entirely on **CPU**, no GPU required.
- Loads base Gemma model + trained LoRA weights (`PeftModel.from_pretrained`).
- Optionally merges LoRA with base model.
- Expands the short prompt → long descriptive text using standard generation parameters (e.g., top-p / top-k sampling).
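
For example, a sketch of calling the CPU inference helper defined in this script (model and repo names are the UI defaults; the short prompt is made up):

```python
long_prompt = generate_long_prompt_cpu(
    base_model="google/gemma-3-4b-it",
    lora_repo="rahul7star/gemma-3-270m-ccebc0",
    short_prompt="a cat sitting on a windowsill",   # hypothetical input
    max_length=200,
)
print(long_prompt)
```
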
---
#### **6️⃣ LoRA Internals Explained**
- LoRA injects low-rank matrices (A, B) into **attention Linear layers**.
- Example:
  \[
  Q_{new} = Q + \alpha \times (B @ A)
  \]
- Significantly reduces training cost:
  - Memory: ~1–2% of full model
  - Compute: trains faster with minimal GPU load
- Scalable to large models like Gemma 3B / 4B with rank ≤ 16.
---
#### **7️⃣ Gradio UI Structure**
- **Train LoRA Tab:**  
  Configure model, dataset, LoRA parameters, and upload target.  
  Press **🚀 Start Training** to stream training logs live.
- **Inference (CPU) Tab:**  
  Type a short prompt → Generates expanded long-form version via trained LoRA.
- **Code Explain Tab:**  
  Detailed breakdown of logic + simulated console output below.
---
### 🧾 Example Log Simulation
```python
print(f"[INFO] Loading base model: {base_model}")
# -> Loads Gemma base model (fp16) on CUDA
# [INFO] Base model google/gemma-3-4b-it loaded successfully
print(f"[INFO] Preparing dataset from: {dataset_path}")
# -> Loads dataset or CSV file
# [DATA] 980 samples loaded, columns: short_prompt, long_prompt
print("[INFO] Initializing LoRA configuration...")
# -> Creates LoraConfig(r=8, alpha=16, target_modules=['q_proj', 'v_proj'])
# [CONFIG] LoRA applied to 96 attention layers
print("[INFO] Starting training loop...")
# [TRAIN] Step 1 | Loss: 2.31
# [TRAIN] Step 50 | Loss: 1.42
# [TRAIN] Step 100 | Loss: 0.91
# [TRAIN] Epoch 1 complete (avg loss: 1.21)
print("[INFO] Saving LoRA adapter...")
# -> Saves safetensors and config locally
print(f"[UPLOAD] Pushing adapter to {hf_repo_id}")
# -> Uploads model to Hugging Face Hub
# [UPLOAD] adapter_model.safetensors (67.7 MB)
# [SUCCESS] LoRA uploaded successfully 🚀
```
---

#### **8️⃣ 🧠 What LoRA Does (A & B Injection Explained)**

When you fine-tune a large model (like Gemma or Llama), you’re adjusting **billions** of parameters in large weight matrices.  
LoRA avoids this by **injecting two small low-rank matrices (A and B)** into selected layers instead of modifying the full weight.

---

##### **Step 1: Regular Linear Layer**

\[
y = W x
\]

Here, **W** is a huge matrix (e.g., 4096×4096).

---

##### **Step 2: LoRA Layer Modification**

Instead of updating W directly, LoRA adds a lightweight update:

\[
W' = W + \Delta W
\]
\[
\Delta W = B A
\]

Where:
- **A** ∈ ℝ^(r × d)
- **B** ∈ ℝ^(d × r)
- and **r ≪ d** (e.g., r=8 instead of 4096)

So you’re training only a *tiny fraction* of parameters.
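
As a rough worked example (d = 4096 and r = 8 are illustrative, not the exact Gemma shapes):

```python
d, r = 4096, 8
full_update = d * d                  # updating W directly: 16,777,216 values
lora_update = r * d + d * r          # A (r x d) plus B (d x r): 65,536 values
print(f"{lora_update / full_update:.2%}")   # ~0.39% of the full matrix
```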

---

##### **Step 3: Where LoRA Gets Injected**

It targets critical sub-layers such as:
- **q_proj, k_proj, v_proj** → Query, Key, Value projections in attention  
- **o_proj / out_proj** → Output projection  
- **gate_proj, up_proj, down_proj** → Feed-forward layers  

When you see:
> `Adapter (90)`

That means 90 total layers (from these modules) were wrapped with LoRA adapters.

---

##### **Step 4: Training Efficiency**

- Base weights (`W`) stay **frozen**
- Only `(A, B)` are **trainable**
- Compute and memory are drastically reduced

| Metric | Full Fine-Tune | LoRA Fine-Tune |
|---------|----------------|----------------|
| Trainable Params | 2B+ | ~3M |
| GPU Memory | 40GB+ | <6GB |
| Time | 10–20 hrs | <1 hr |

---

##### **Step 5: Inference Equation**

At inference time:
\[
y = (W + \alpha \times B A) x
\]

Where **α** controls the strength of the adapter’s influence.
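
A tiny sanity check of that identity (shapes and values here are arbitrary):

```python
import torch

d, r, alpha = 16, 4, 2.0
W, A, B = torch.randn(d, d), torch.randn(r, d), torch.randn(d, r)
x = torch.randn(d)

y_adapter = W @ x + alpha * (B @ (A @ x))   # adapter applied on the fly
y_merged  = (W + alpha * B @ A) @ x         # equivalent merged weight
print(torch.allclose(y_adapter, y_merged, atol=1e-5))   # True
```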

---

##### **Step 6: Visualization**

```
Base Layer:
y = W * x

LoRA Layer:
y = (W + B@A) * x
         ↑ ↑
         │ └── Small rank-A adapter (trainable)
         └──── Small rank-B adapter (trainable)
```

---

##### **Step 7: Example in Code**

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05
)

model = get_peft_model(model, config)
model.print_trainable_parameters()
```

Expected output:

```
trainable params: 3,278,848 || all params: 2,040,000,000 || trainable%: 0.16%
```
""")
    return demo


if __name__ == "__main__":
    run_ui().launch(server_name="0.0.0.0", server_port=7860, share=True)