| """ | |
| AIFinder Data Loader | |
| Downloads and parses HuggingFace datasets, extracts assistant responses, | |
| and labels them with is_ai, provider, and model. | |
| """ | |
import hashlib
import json
import os
import random
import re
import time
from collections import Counter, defaultdict

from datasets import load_dataset
from tqdm import tqdm

from config import (
    DATASET_REGISTRY,
    DEEPSEEK_AM_DATASETS,
    MAX_SAMPLES_PER_PROVIDER,
)
# HuggingFace access token (None when unset); forwarded to load_dataset,
# presumably for gated/private repos — confirm against the dataset registry.
HF_TOKEN = os.environ.get("HF_TOKEN")
| def _parse_msg(msg): | |
| """Parse a message that may be a dict or a JSON string.""" | |
| if isinstance(msg, dict): | |
| return msg | |
| if isinstance(msg, str): | |
| try: | |
| import json as _json | |
| parsed = _json.loads(msg) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except (ValueError, Exception): | |
| pass | |
| return {} | |
| def _extract_response_only(content): | |
| """Extract only the final response, stripping CoT blocks. | |
| Returns only the text after </think> or </thinking> if present, | |
| otherwise returns the full content. | |
| """ | |
| if not content: | |
| return "" | |
| think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL) | |
| if think_match: | |
| response = think_match.group(1).strip() | |
| if response: | |
| return response | |
| return content | |
def _extract_assistant_texts_from_conversations(rows):
    """Collect assistant turns from conversation-format rows.

    Each assistant message becomes one sample (turns are never
    concatenated), and only the response portion (text after a closing
    </think> tag, if any) is kept.
    """

    def _is_missing(value):
        # None or an empty sized container means "no conversation here".
        return value is None or (hasattr(value, "__len__") and len(value) == 0)

    collected = []
    for row in rows:
        turns = row.get("conversations")
        if _is_missing(turns):
            turns = row.get("messages")
        if _is_missing(turns):
            turns = []
        for raw in turns:
            parsed = _parse_msg(raw)
            if parsed.get("role", "") not in ("assistant", "gpt", "model"):
                continue
            body = parsed.get("content", "")
            if not body:
                continue
            answer = _extract_response_only(body)
            if answer:
                collected.append(answer)
    return collected
def _extract_from_am_dataset(row):
    """Pull assistant response texts out of an a-m-team-format row.

    Only the response portion (text after a closing </think> tag, if any)
    of each assistant message is returned; non-dict entries are skipped.
    """
    results = []
    for entry in row.get("messages") or row.get("conversations") or []:
        if not isinstance(entry, dict):
            continue
        if entry.get("role", "") != "assistant":
            continue
        text = entry.get("content", "")
        if not text:
            continue
        answer = _extract_response_only(text)
        if answer:
            results.append(answer)
    return results
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset.

    Returns:
        (texts, providers, models) — parallel lists, empty when the
        dataset cannot be loaded or yields no usable text.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {"name": kwargs["name"]} if "name" in kwargs else {}
    try:
        rows = list(load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs))
    except Exception as primary_err:
        # Fallback: read the auto-converted parquet shard via the HF API.
        try:
            import pandas as pd
            shard_url = (
                f"https://huggingface.co/api/datasets/{dataset_id}"
                "/parquet/default/train/0.parquet"
            )
            rows = pd.read_parquet(shard_url).to_dict(orient="records")
        except Exception as fallback_err:
            print(f" [SKIP] {dataset_id}: {primary_err} / parquet fallback: {fallback_err}")
            return [], [], []
    if max_samples and len(rows) > max_samples:
        import random
        random.seed(42)  # deterministic subsample across runs
        rows = random.sample(rows, max_samples)
    # Label every extracted turn; drop very short texts.
    labeled = [
        (text, provider, model_name)
        for text in _extract_assistant_texts_from_conversations(rows)
        if len(text) > 50
    ]
    if not labeled:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []
    texts, providers, models = zip(*labeled)
    return list(texts), list(providers), list(models)
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load one a-m-team DeepSeek dataset.

    Falls back to streaming mode when the plain load fails; in either
    mode at most max_samples rows are kept (first N, not random).

    Returns:
        (texts, providers, models) — parallel lists, empty on failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {"name": kwargs["name"]} if "name" in kwargs else {}
    try:
        ds = load_dataset(dataset_id, split="train", token=HF_TOKEN, **load_kwargs)
    except Exception:
        # Streaming fallback: pull rows one at a time, stopping early.
        try:
            stream = load_dataset(
                dataset_id, split="train", streaming=True, token=HF_TOKEN, **load_kwargs
            )
            rows = []
            for record in stream:
                rows.append(record)
                if max_samples and len(rows) >= max_samples:
                    break
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        rows = list(ds)
        if max_samples and len(rows) > max_samples:
            rows = rows[:max_samples]
    # Keep only sufficiently long extracted responses.
    texts = [
        text
        for row in rows
        for text in _extract_from_am_dataset(row)
        if len(text) > 50
    ]
    return texts, [provider] * len(texts), [model_name] * len(texts)
def _load_registry(registry, loader, desc):
    """Run loader over every (dataset_id, provider, model, kwargs) entry,
    printing per-dataset sample counts and timings."""
    texts, providers, models = [], [], []
    for dataset_id, provider, model_name, kwargs in tqdm(registry, desc=desc):
        t0 = time.time()
        t, p, m = loader(dataset_id, provider, model_name, kwargs)
        elapsed = time.time() - t0
        texts.extend(t)
        providers.extend(p)
        models.extend(m)
        print(f" {dataset_id}: {len(t)} samples ({elapsed:.1f}s)")
    return texts, providers, models


def _deduplicate(texts, providers, models):
    """Drop samples whose normalized (stripped, lowercased) text hash was
    already seen; keeps the first occurrence."""
    seen = set()
    out_t, out_p, out_m = [], [], []
    for t, p, m in zip(texts, providers, models):
        h = hashlib.md5(t.strip().lower().encode()).hexdigest()
        if h not in seen:
            seen.add(h)
            out_t.append(t)
            out_p.append(p)
            out_m.append(m)
    n_dupes = len(texts) - len(out_t)
    if n_dupes > 0:
        print(f"\n Removed {n_dupes} duplicate samples")
    return out_t, out_p, out_m


def _balance_by_provider(texts, providers, models, rng):
    """Cap each provider at MAX_SAMPLES_PER_PROVIDER via a shuffle with
    the caller-seeded rng; original relative order is restored by sorting
    the kept indices."""
    provider_indices = defaultdict(list)
    for i, p in enumerate(providers):
        provider_indices[p].append(i)
    keep_indices = []
    for p, idxs in provider_indices.items():
        rng.shuffle(idxs)
        idxs = idxs[: min(len(idxs), MAX_SAMPLES_PER_PROVIDER)]
        print(f" Sampled {p}: {len(idxs)} samples")
        keep_indices.extend(idxs)
    keep_indices.sort()
    return (
        [texts[i] for i in keep_indices],
        [providers[i] for i in keep_indices],
        [models[i] for i in keep_indices],
    )


def load_all_data():
    """Load all datasets and return combined lists.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human)
    """
    # TeichAI datasets
    print("Loading TeichAI datasets...")
    all_texts, all_providers, all_models = _load_registry(
        DATASET_REGISTRY, load_teichai_dataset, "TeichAI"
    )
    # DeepSeek a-m-team datasets
    print("\nLoading DeepSeek (a-m-team) datasets...")
    t, p, m = _load_registry(DEEPSEEK_AM_DATASETS, load_am_deepseek_dataset, "DeepSeek-AM")
    all_texts.extend(t)
    all_providers.extend(p)
    all_models.extend(m)
    # Deduplicate, then balance providers with a deterministic shuffle.
    random.seed(42)
    all_texts, all_providers, all_models = _deduplicate(all_texts, all_providers, all_models)
    all_texts, all_providers, all_models = _balance_by_provider(
        all_texts, all_providers, all_models, random
    )
    # Every sample here is AI-generated; human samples come from elsewhere.
    is_ai = [1] * len(all_texts)
    print(f"\n=== Total: {len(all_texts)} samples ===")
    for prov, count in sorted(Counter(all_providers).items(), key=lambda x: -x[1]):
        print(f" {prov}: {count}")
    return all_texts, all_providers, all_models, is_ai
| if __name__ == "__main__": | |
| texts, providers, models, is_ai = load_all_data() | |