import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
from mtdna_classifier import classify_sample_location
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
@lru_cache(maxsize=128)
def classify_sample_location_cached(accession):
    return classify_sample_location(accession)
# Count and suggest final location
def compute_final_suggested_location(rows):
    candidates = [
        row.get("Predicted Location", "").strip()
        for row in rows
        if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
    ] + [
        row.get("Inferred Region", "").strip()
        for row in rows
        if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
    ]
    if not candidates:
        return Counter(), ("Unknown", 0)
    # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
    tokens = []
    for item in candidates:
        # Split by comma, whitespace, and newlines
        parts = re.split(r'[\s,]+', item)
        tokens.extend(parts)
    # Step 2: Clean and normalize tokens
    tokens = [word.strip() for word in tokens if word.strip().isalpha()]  # Keep only alphabetic tokens
    # Step 3: Count
    counts = Counter(tokens)
    # Step 4: Get the most common location
    top_location, count = counts.most_common(1)[0]
    return counts, (top_location, count)
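# Illustrative example (hypothetical rows, not real classifier output):
#   rows = [{"Predicted Location": "Vietnam", "Inferred Region": "unknown"},
#           {"Predicted Location": "Hanoi, Vietnam", "Inferred Region": "Vietnam"}]
#   compute_final_suggested_location(rows)
#   -> (Counter({"Vietnam": 3, "Hanoi": 1}), ("Vietnam", 3))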
# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
    if not answer1.strip() or not answer2.strip():
        return "⚠️ Please answer both questions before submitting."
    try:
        # ✅ Step: Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        # Connect to Google Sheet
        client = gspread.authorize(creds)
        sheet = client.open("feedback_mtdna").sheet1  # make sure the sheet name matches
        # Append feedback
        sheet.append_row([accession, answer1, answer2, contact])
        return "✅ Feedback submitted. Thank you!"
    except Exception as e:
        return f"❌ Error submitting feedback: {e}"
# helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
    print(f"RAW TEXT RECEIVED: {raw_text}")
    accessions = []
    seen = set()
    if file:
        try:
            if file.name.endswith(".csv"):
                df = pd.read_csv(file)
            elif file.name.endswith(".xlsx"):
                df = pd.read_excel(file)
            else:
                return [], "Unsupported file format. Please upload CSV or Excel."
            for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
                if acc not in seen:
                    accessions.append(acc)
                    seen.add(acc)
        except Exception as e:
            return [], f"Failed to read file: {e}"
    if raw_text:
        text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
        for acc in text_ids:
            if acc not in seen:
                accessions.append(acc)
                seen.add(acc)
    return list(accessions), None
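# Pasted text is split on newlines, commas, semicolons, and tabs (not plain spaces),
# and IDs already seen (from the uploaded file or earlier in the text) are skipped.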
def summarize_results(accession):
    try:
        output, labelAncient_Modern, explain_label = classify_sample_location_cached(accession)
        #print(output)
    except Exception as e:
        return [], f"Error: {e}", f"Error: {e}", f"Error: {e}"
    if accession not in output:
        return [], "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
    isolate = next((k for k in output if k != accession), None)
    row_score = []
    rows = []
    for key in [accession, isolate]:
        if key not in output:
            continue
        sample_id_label = f"{key} ({'accession number' if key == accession else 'isolate of accession'})"
        for section, techniques in output[key].items():
            for technique, content in techniques.items():
                source = content.get("source", "")
                predicted = content.get("predicted_location", "")
                haplogroup = content.get("haplogroup", "")
                inferred = content.get("inferred_location", "")
                context = content.get("context_snippet", "")[:300] if "context_snippet" in content else ""
                row = {
                    "Sample ID": sample_id_label,
                    "Technique": technique,
                    "Source": f"The region of the haplogroup is inferred\nusing this source: {source}" if technique == "haplogroup" else source,
                    "Predicted Location": "" if technique == "haplogroup" else predicted,
                    "Haplogroup": haplogroup if technique == "haplogroup" else "",
                    "Inferred Region": inferred if technique == "haplogroup" else "",
                    "Context Snippet": context
                }
                row_score.append(row)
                rows.append(list(row.values()))
    location_counts, (final_location, count) = compute_final_suggested_location(row_score)
    summary_lines = ["### 🧭 Location Frequency Summary", "After counting all predicted and inferred locations:\n"]
    summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
    summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
    summary = "\n".join(summary_lines)
    return rows, summary, labelAncient_Modern, explain_label
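# Assumed shape of the classifier output consumed above (inferred from the loops, not documented here):
# output = {sample_key: {section: {technique: {"source": ..., "predicted_location": ...,
#           "haplogroup": ..., "inferred_location": ..., "context_snippet": ...}}}}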
# Save the batch output to an Excel file
def save_to_excel(all_rows, summary_text, flag_text, filename):
    with pd.ExcelWriter(filename) as writer:
        # Save table
        df = pd.DataFrame(all_rows, columns=["Sample ID", "Technique", "Source", "Predicted Location", "Haplogroup", "Inferred Region", "Context Snippet"])
        df.to_excel(writer, sheet_name="Detailed Results", index=False)
        # Save summary
        summary_df = pd.DataFrame({"Summary": [summary_text]})
        summary_df.to_excel(writer, sheet_name="Summary", index=False)
        # Save flag
        flag_df = pd.DataFrame({"Flag": [flag_text]})
        flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
# Save the batch output to a JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
    output_dict = {
        "Detailed_Results": all_rows,  # <-- make sure this is a plain list, not a DataFrame
        "Summary_Text": summary_text,
        "Ancient_Modern_Flag": flag_text
    }
    # If all_rows is a DataFrame, convert it
    if isinstance(all_rows, pd.DataFrame):
        output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
    with open(filename, "w") as external_file:
        json.dump(output_dict, external_file, indent=2)
# Save the batch output to a text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
    if isinstance(all_rows, pd.DataFrame):
        detailed_results = all_rows.to_dict(orient="records")
    else:
        detailed_results = all_rows  # assume it is already a list of row dicts
    output = ""
    output += ",".join(list(detailed_results[0].keys())) + "\n\n"
    for r in detailed_results:
        output += ",".join([str(v) for v in r.values()]) + "\n\n"
    with open(filename, "w") as f:
        f.write("=== Detailed Results ===\n")
        f.write(output + "\n")
        f.write("\n=== Summary ===\n")
        f.write(summary_text + "\n")
        f.write("\n=== Ancient/Modern Flag ===\n")
        f.write(flag_text + "\n")
def save_batch_output(all_rows, summary_text, flag_text, output_type):
    tmp_dir = tempfile.mkdtemp()
    #html_table = all_rows.value # assuming this is stored somewhere
    # Parse back to DataFrame
    #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
    all_rows = pd.read_html(StringIO(all_rows))[0]
    print(all_rows)
    if output_type == "Excel":
        file_path = f"{tmp_dir}/batch_output.xlsx"
        save_to_excel(all_rows, summary_text, flag_text, file_path)
    elif output_type == "JSON":
        file_path = f"{tmp_dir}/batch_output.json"
        save_to_json(all_rows, summary_text, flag_text, file_path)
        print("Done with JSON")
    elif output_type == "TXT":
        file_path = f"{tmp_dir}/batch_output.txt"
        save_to_txt(all_rows, summary_text, flag_text, file_path)
    else:
        return gr.update(visible=False)  # invalid option
    return gr.update(value=file_path, visible=True)
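# Note: all_rows arrives here as an HTML table string, so pd.read_html needs an HTML
# parser backend (lxml, or html5lib/BeautifulSoup) available in the Space environment.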
# run the batch
def summarize_batch(file=None, raw_text=""):
    accessions, error = extract_accessions_from_input(file, raw_text)
    if error:
        # Keep the same arity as the success return below
        return [], f"Error: {error}", "", gr.update(visible=False), gr.update(visible=False)
    all_rows = []
    all_summaries = []
    all_flags = []
    for acc in accessions:
        try:
            rows, summary, label, explain = summarize_results(acc)
            all_rows.extend(rows)
            all_summaries.append(f"**{acc}**\n{summary}")
            all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
        except Exception as e:
            all_summaries.append(f"**{acc}**: Failed - {e}")
    """for row in all_rows:
        source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
        if source_column.startswith("http"): # Check if the source is a URL
            # Wrap it with HTML anchor tags to make it clickable
            row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
    summary_text = "\n\n---\n\n".join(all_summaries)
    flag_text = "\n\n---\n\n".join(all_flags)
    return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)