rahul7star committed on
Commit
609c7e3
·
verified ·
1 Parent(s): c9e5f78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -99
app.py CHANGED
@@ -1,113 +1,71 @@
1
- import os, torch, gradio as gr, pandas as pd, numpy as np
2
- from pathlib import Path
3
- from tqdm.auto import tqdm
4
- from huggingface_hub import create_repo, upload_folder, hf_hub_download
5
- from torch.utils.data import Dataset, DataLoader
6
- import torch.nn as nn
7
 
8
- # ============================================================
9
- # 🧠 Intelligent dataset loader
10
- # ============================================================
11
- from datasets import load_dataset, DatasetDict
12
-
13
- def load_dataset_intelligent(source: str, subset: str = None):
14
  """
15
- 🔍 Intelligent dataset loader for CSV, Parquet, or HF Hub.
16
- Detects:
17
- - Local CSV/parquet file
18
- - Local folder containing CSVs
19
- - Hugging Face Hub dataset repo
20
- Returns dict of {split: DataFrame}
21
  """
 
 
 
 
 
 
 
22
 
23
- def try_load_local_csv(path):
24
- if os.path.exists(path) and path.endswith((".csv", ".parquet")):
25
- print(f"📄 Loading local file: {path}")
26
- return pd.read_parquet(path) if path.endswith(".parquet") else pd.read_csv(path)
27
- return None
 
 
28
 
29
- def try_load_local_folder(path):
30
- if os.path.isdir(path):
31
- csv_files = [f for f in os.listdir(path) if f.endswith((".csv", ".parquet"))]
32
- if csv_files:
33
- print(f"📁 Found folder with {len(csv_files)} data files in: {path}")
34
- dataframes = {}
35
- for file in csv_files:
36
- split_name = "train" if "train" in file else os.path.splitext(file)[0]
37
- fpath = os.path.join(path, file)
38
- df = pd.read_parquet(fpath) if fpath.endswith(".parquet") else pd.read_csv(fpath)
39
- dataframes[split_name] = df
40
- return dataframes
41
- return None
42
 
43
- # 1️⃣ Local file
44
- df = try_load_local_csv(source)
45
- if df is not None:
46
- return {"train": df}
47
 
48
- # 2️⃣ Folder with CSVs
49
- dfs = try_load_local_folder(source)
50
- if dfs is not None:
51
- return dfs
 
 
 
 
 
 
52
 
53
- # 3️⃣ Hugging Face Hub
54
- print(f"🌐 Attempting to load from Hugging Face Hub: {source}")
55
- try:
56
- ds = load_dataset(source, subset or None)
57
- if isinstance(ds, DatasetDict):
58
- print(f"✅ Loaded HF dataset with splits: {list(ds.keys())}")
59
- return {split: ds[split].to_pandas() for split in ds.keys()}
60
- else:
61
- print("✅ Loaded single-split HF dataset")
62
- return {"train": ds.to_pandas()}
63
- except Exception as e:
64
- raise FileNotFoundError(f"❌ Could not load dataset: {source}\nError: {str(e)}")
65
 
66
- # ============================================================
67
- # πŸ“ Diffusion Dataset (uses intelligent loader)
68
- # ============================================================
69
- class MediaTextDataset(Dataset):
70
- def __init__(self, source, csv_name="dataset.csv", max_frames=5):
71
- self.source = source
72
- self.max_frames = max_frames
73
- self.data_splits = load_dataset_intelligent(source)
74
 
75
- # Auto-pick train split
76
- self.df = self.data_splits.get("train") or list(self.data_splits.values())[0]
77
- self.root = Path(source) if os.path.isdir(source) else None
78
 
79
- import torchvision.transforms as T
80
- self.img_tf = T.Compose([
81
- T.ToPILImage(), T.Resize((512,512)),
82
- T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)
83
- ])
84
- self.video_tf = T.Compose([
85
- T.ToPILImage(), T.Resize((128,256)),
86
- T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)
87
- ])
88
 
89
- def __len__(self): return len(self.df)
 
 
90
 
91
- def __getitem__(self, i):
92
- import torchvision
93
- rec = self.df.iloc[i]
94
- fname = rec.get("file_name") or rec.get("image") or rec.get("path")
95
- text = rec.get("text") or rec.get("caption") or rec.get("prompt")
96
- p = Path(self.root / fname) if self.root else Path(fname)
97
- if not p.exists():
98
- raise FileNotFoundError(f"Missing file: {p}")
99
 
100
- if p.suffix.lower() in {".jpg",".jpeg",".png",".webp"}:
101
- img = torchvision.io.read_image(str(p))
102
- if isinstance(img, torch.Tensor): img = img.permute(1,2,0).numpy()
103
- return {"type": "image", "image": self.img_tf(img), "caption": text}
104
- elif p.suffix.lower() in {".mp4",".mov",".avi",".mkv"}:
105
- vid,_,_ = torchvision.io.read_video(str(p))
106
- total = len(vid)
107
- if total == 0:
108
- return {"type":"video","frames":torch.zeros((self.max_frames,3,128,256))}
109
- idxs = np.linspace(0,total-1,self.max_frames).round().astype(int)
110
- frames = torch.stack([self.video_tf(vid[j].numpy()) for j in idxs])
111
- return {"type":"video","frames":frames,"caption":text}
112
- else:
113
- raise RuntimeError(f"Unsupported media: {p}")
 
1
+ import os
2
+ import pandas as pd
3
+ from datasets import load_dataset
4
+ import gradio as gr
 
 
5
 
6
+ def load_data(source_path):
 
 
 
 
 
7
  """
8
+ Load dataset from either a local CSV file or a Hugging Face dataset path.
9
+ Automatically detects which type of source to use.
 
 
 
 
10
  """
11
+ try:
12
+ # --- Case 1: Local CSV file ---
13
+ if os.path.exists(source_path):
14
+ print(f"📂 Loading local dataset from: {source_path}")
15
+ df = pd.read_csv(source_path)
16
+ print(f"✅ Loaded {len(df)} rows from local CSV.")
17
+ return df
18
 
19
+ # --- Case 2: Hugging Face dataset ---
20
+ elif "/" in source_path:
21
+ print(f"🌐 Loading Hugging Face dataset: {source_path}")
22
+ dataset = load_dataset(source_path, split="train")
23
+ df = dataset.to_pandas()
24
+ print(f"✅ Loaded {len(df)} rows from Hugging Face dataset.")
25
+ return df
26
 
27
+ else:
28
+ raise FileNotFoundError("Invalid path: not a local file or HF dataset.")
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ except Exception as e:
31
+ print(f"❌ Error loading data: {e}")
32
+ return pd.DataFrame()
 
33
 
34
+ def summarize_dataset(df):
35
+ """
36
+ Return a brief summary of the dataset for display in Gradio.
37
+ """
38
+ if df.empty:
39
+ return "❌ No data loaded.", ""
40
+
41
+ preview = df.head().to_markdown(index=False)
42
+ info = f"✅ Loaded {len(df)} rows and {len(df.columns)} columns.\n\n**Columns:** {', '.join(df.columns)}"
43
+ return info, preview
44
 
45
+ def gradio_ui():
46
+ with gr.Blocks(title="Prompt Enhancer Data Loader") as demo:
47
+ gr.Markdown("## 🧠 Intelligent Dataset Loader")
48
+ gr.Markdown("Automatically loads from a local CSV file **or** a Hugging Face dataset repo.")
 
 
 
 
 
 
 
 
49
 
50
+ with gr.Row():
51
+ dataset_path = gr.Textbox(
52
+ label="Enter dataset path (local or HF repo)",
53
+ value="rahul7star/prompt-enhancer-dataset-01",
54
+ placeholder="e.g., /path/to/local.csv or username/dataset-name",
55
+ )
 
 
56
 
57
+ load_btn = gr.Button("🚀 Load Dataset")
 
 
58
 
59
+ output_info = gr.Markdown()
60
+ output_preview = gr.Markdown()
 
 
 
 
 
 
 
61
 
62
+ def handle_load(path):
63
+ df = load_data(path)
64
+ return summarize_dataset(df)
65
 
66
+ load_btn.click(handle_load, inputs=[dataset_path], outputs=[output_info, output_preview])
 
 
 
 
 
 
 
67
 
68
+ return demo
69
+
70
+ if __name__ == "__main__":
71
+ gradio_ui().launch(server_name="0.0.0.0", server_port=7860)