# mtDNALocation / mtdna_backend.py
# VyLala's picture
# update new codes
# 7fc87fe verified
# raw
# history blame
# 10.4 kB
import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
from mtdna_classifier import classify_sample_location
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
@lru_cache(maxsize=128)
def classify_sample_location_cached(accession):
    """Memoized front-end to ``classify_sample_location``.

    Results for up to 128 distinct ``accession`` arguments are kept in an
    LRU cache, so a repeated lookup of the same accession does not call the
    classifier again.
    """
    classification = classify_sample_location(accession)
    return classification
# Count and suggest final location
def compute_final_suggested_location(rows):
    """Tally location words across result rows and pick the most frequent one.

    Args:
        rows: iterable of dicts; the "Predicted Location" and
            "Inferred Region" entries of each row are inspected.

    Returns:
        ``(counts, (top_location, count))`` where ``counts`` is a ``Counter``
        of the alphabetic words found.  When no usable word exists the result
        is ``(Counter(), ("Unknown", 0))``.
    """
    ignored = {"", "sample id not found", "unknown"}

    def _usable_values(column):
        # Collect the stripped, non-placeholder values of one column.
        values = []
        for row in rows:
            # ``or ""`` lets a row that stores None behave like an empty cell.
            text = (row.get(column) or "").strip()
            if text.lower() not in ignored:
                values.append(text)
        return values

    # All predicted locations come first, then all inferred regions; Counter
    # breaks ties by first appearance, so this order is kept on purpose.
    candidates = _usable_values("Predicted Location") + _usable_values("Inferred Region")
    if not candidates:
        return Counter(), ("Unknown", 0)

    # Step 1: split every candidate on commas, whitespace and line breaks.
    tokens = []
    for item in candidates:
        tokens.extend(re.split(r"[\s,]+", item))

    # Step 2: keep only purely alphabetic tokens.
    tokens = [word for word in tokens if word.isalpha()]

    # Step 3: count.
    counts = Counter(tokens)
    if not counts:
        # Candidates existed but none contained an alphabetic word (e.g. "123");
        # without this guard most_common(1)[0] raises IndexError.
        return counts, ("Unknown", 0)

    # Step 4: most common word wins.
    top_location, count = counts.most_common(1)[0]
    return counts, (top_location, count)
# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
    """Append one feedback row to the "feedback_mtdna" Google Sheet.

    Args:
        accession: accession the feedback refers to.
        answer1: first required answer; blank or None is rejected.
        answer2: second required answer; blank or None is rejected.
        contact: optional contact string stored in the last column.

    Returns:
        A user-facing status message: a warning when an answer is missing,
        a confirmation on success, or an error message when anything in the
        credential/sheet steps raises.
    """
    # None counts as "not answered" instead of raising AttributeError on .strip().
    if not (answer1 or "").strip() or not (answer2 or "").strip():
        return "⚠️ Please answer both questions before submitting."
    try:
        # Service-account credentials are read as JSON from the
        # GCP_CREDS_JSON environment variable.
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        # Connect to Google Sheet
        client = gspread.authorize(creds)
        sheet = client.open("feedback_mtdna").sheet1  # make sure sheet name matches
        # Append feedback
        sheet.append_row([accession, answer1, answer2, contact])
        return "✅ Feedback submitted. Thank you!"
    except Exception as e:
        # Boundary handler: the message is shown to the user instead of crashing.
        return f"❌ Error submitting feedback: {e}"
# helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
    """Collect unique accession IDs from an uploaded file and/or free text.

    Args:
        file: optional object with a ``name`` attribute that pandas can read;
            ``.csv`` and ``.xlsx`` names are accepted (case-insensitively) and
            the first column supplies the accessions.
        raw_text: optional string of IDs separated by newlines, commas,
            semicolons or tabs.

    Returns:
        ``(accessions, error)``: on success a de-duplicated list in first-seen
        order (file entries first) and ``None``; on failure ``[]`` and a
        message describing the problem.
    """
    print(f"RAW TEXT RECEIVED: {raw_text}")
    accessions = []
    seen = set()

    def _add_unique(values):
        # Append values not seen before, preserving order; skip blank entries.
        for acc in values:
            if acc and acc not in seen:
                seen.add(acc)
                accessions.append(acc)

    if file:
        try:
            # Lower-case the name so "IDS.CSV" / "ids.XLSX" are accepted too.
            file_name = file.name.lower()
            if file_name.endswith(".csv"):
                df = pd.read_csv(file)
            elif file_name.endswith(".xlsx"):
                df = pd.read_excel(file)
            else:
                return [], "Unsupported file format. Please upload CSV or Excel."
            _add_unique(df.iloc[:, 0].dropna().astype(str).str.strip())
        except Exception as e:
            # Any read/parse problem is reported to the caller as a message.
            return [], f"Failed to read file: {e}"

    if raw_text:
        _add_unique(s.strip() for s in re.split(r"[\n,;\t]", raw_text))

    return accessions, None
def summarize_results(accession):
    """Classify one accession and flatten the result into table rows plus text.

    Returns a 4-tuple ``(rows, summary, label, explanation)``:

    * ``rows`` - one list per (section, technique) entry, in the column order
      Sample ID, Technique, Source, Predicted Location, Haplogroup,
      Inferred Region, Context Snippet;
    * ``summary`` - markdown listing how often each location word occurred
      and the final suggested location;
    * ``label`` / ``explanation`` - the second and third values returned by
      the cached classifier, passed through unchanged.

    If classification raises, or the accession is missing from the classifier
    output, ``rows`` is empty and the other three items carry the same
    message.
    """
    try:
        output, labelAncient_Modern, explain_label = classify_sample_location_cached(accession)
    except Exception as e:
        failure = f"Error: {e}"
        return [], failure, failure, failure

    if accession not in output:
        missing = "Accession not found in results."
        return [], missing, missing, missing

    # The first key that is not the accession itself is treated as its isolate.
    isolate = next((k for k in output if k != accession), None)

    row_score = []
    rows = []
    for key in (accession, isolate):
        if key not in output:
            continue
        role = "accession number" if key == accession else "isolate of accession"
        sample_id_label = f"{key} ({role})"
        for techniques in output[key].values():
            for technique, content in techniques.items():
                from_haplogroup = technique == "haplogroup"
                source = content.get("source", "")
                predicted = content.get("predicted_location", "")
                haplogroup = content.get("haplogroup", "")
                inferred = content.get("inferred_location", "")
                # Snippets are capped at 300 characters for display.
                if "context_snippet" in content:
                    context = content.get("context_snippet", "")[:300]
                else:
                    context = ""

                if from_haplogroup:
                    source_cell = f"The region of haplogroup is inferred\nby using this source: {source}"
                    predicted_cell, haplogroup_cell, inferred_cell = "", haplogroup, inferred
                else:
                    source_cell = source
                    predicted_cell, haplogroup_cell, inferred_cell = predicted, "", ""

                row = {
                    "Sample ID": sample_id_label,
                    "Technique": technique,
                    "Source": source_cell,
                    "Predicted Location": predicted_cell,
                    "Haplogroup": haplogroup_cell,
                    "Inferred Region": inferred_cell,
                    "Context Snippet": context,
                }
                row_score.append(row)
                rows.append(list(row.values()))

    location_counts, (final_location, count) = compute_final_suggested_location(row_score)

    summary_lines = [
        "### 🧭 Location Frequency Summary",
        "After counting all predicted and inferred locations:\n",
    ]
    for loc, cnt in location_counts.items():
        summary_lines.append(f"- **{loc}**: {cnt} times")
    summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
    summary = "\n".join(summary_lines)

    return rows, summary, labelAncient_Modern, explain_label
# save the batch input in excel file
def save_to_excel(all_rows, summary_text, flag_text, filename):
    """Write the batch results to an Excel workbook with three sheets.

    Sheets, in order: "Detailed Results" (the row table), "Summary" (a single
    cell holding ``summary_text``) and "Ancient_Modern_Flag" (a single cell
    holding ``flag_text``).  No index column is written.
    """
    detail_columns = [
        "Sample ID",
        "Technique",
        "Source",
        "Predicted Location",
        "Haplogroup",
        "Inferred Region",
        "Context Snippet",
    ]
    with pd.ExcelWriter(filename) as writer:
        # Table of per-technique rows
        pd.DataFrame(all_rows, columns=detail_columns).to_excel(
            writer, sheet_name="Detailed Results", index=False
        )
        # Free-text summary
        pd.DataFrame({"Summary": [summary_text]}).to_excel(
            writer, sheet_name="Summary", index=False
        )
        # Ancient/modern flag text
        pd.DataFrame({"Flag": [flag_text]}).to_excel(
            writer, sheet_name="Ancient_Modern_Flag", index=False
        )
# save the batch input in JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
    """Write the batch results to ``filename`` as indented JSON.

    Args:
        all_rows: either a plain list (stored as-is) or a pandas DataFrame,
            which is converted to a list of per-row dicts.
        summary_text: stored under "Summary_Text".
        flag_text: stored under "Ancient_Modern_Flag".
        filename: path of the file to create or overwrite.

    Missing DataFrame cells (NaN) are written as JSON ``null``; left alone
    they would be emitted as the bare token ``NaN``, which is not valid JSON.
    """

    def _nan_to_none(value):
        # NaN is the only float that compares unequal to itself.
        if isinstance(value, float) and value != value:
            return None
        return value

    if isinstance(all_rows, pd.DataFrame):
        detailed_results = [
            {key: _nan_to_none(value) for key, value in record.items()}
            for record in all_rows.to_dict(orient="records")
        ]
    else:
        detailed_results = all_rows

    output_dict = {
        "Detailed_Results": detailed_results,
        "Summary_Text": summary_text,
        "Ancient_Modern_Flag": flag_text,
    }
    with open(filename, "w") as external_file:
        json.dump(output_dict, external_file, indent=2)
# save the batch input in Text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
    """Write the batch results to ``filename`` as plain text.

    The file has three sections: "Detailed Results" (a comma-joined header
    line followed by one comma-joined line per row, each followed by a blank
    line), "Summary" and "Ancient/Modern Flag".

    Args:
        all_rows: a pandas DataFrame, a list of dicts, or a list of plain
            row sequences (written without a header line).  Empty input
            produces an empty detail section instead of raising.
        summary_text: text of the summary section.
        flag_text: text of the flag section.
        filename: path of the file to create or overwrite.
    """
    if isinstance(all_rows, pd.DataFrame):
        records = all_rows.to_dict(orient="records")
        # With no rows there is no first record, so fall back to the columns.
        header = list(records[0].keys()) if records else list(all_rows.columns)
        value_rows = [list(record.values()) for record in records]
    else:
        records = list(all_rows) if all_rows is not None else []
        if records and isinstance(records[0], dict):
            header = list(records[0].keys())
            value_rows = [list(record.values()) for record in records]
        else:
            header = []
            value_rows = records

    lines = []
    if header:
        lines.append(",".join(str(name) for name in header))
    lines.extend(",".join(str(value) for value in row) for row in value_rows)
    # Every line is followed by a blank line.
    output = "".join(line + "\n\n" for line in lines)

    # Explicit UTF-8: the summary/flag text may contain emoji, which a
    # non-UTF-8 platform default encoding cannot write.
    with open(filename, "w", encoding="utf-8") as f:
        f.write("=== Detailed Results ===\n")
        f.write(output + "\n")
        f.write("\n=== Summary ===\n")
        f.write(summary_text + "\n")
        f.write("\n=== Ancient/Modern Flag ===\n")
        f.write(flag_text + "\n")
def save_batch_output(all_rows, summary_text, flag_text, output_type):
    """Export the batch results to a file in a fresh temporary directory.

    ``all_rows`` is an HTML string; its first table is parsed back into a
    DataFrame before exporting.  ``output_type`` selects the writer:
    "Excel", "JSON" or "TXT".

    Returns a ``gr.update`` that shows the written file path, or one that
    hides the component when ``output_type`` is none of the three options.
    """
    tmp_dir = tempfile.mkdtemp()
    # read_html returns a list of tables; only the first one is used.
    all_rows = pd.read_html(StringIO(all_rows))[0]
    print(all_rows)

    exporters = (
        ("Excel", "batch_output.xlsx", save_to_excel),
        ("JSON", "batch_output.json", save_to_json),
        ("TXT", "batch_output.txt", save_to_txt),
    )
    for label, file_name, export in exporters:
        if output_type == label:
            file_path = f"{tmp_dir}/{file_name}"
            export(all_rows, summary_text, flag_text, file_path)
            if label == "JSON":
                print("Done with JSON")
            return gr.update(value=file_path, visible=True)

    return gr.update(visible=False)  # invalid option
# run the batch
def summarize_batch(file=None, raw_text=""):
    """Run ``summarize_results`` for every accession in the file and/or text.

    On success returns five values: the combined table rows, the per-accession
    summaries joined by a horizontal-rule separator, the per-accession
    ancient/modern flag texts joined the same way, and two ``gr.update``
    objects (visible=True, then visible=False).

    An accession whose processing raises gets a "Failed" line in the summary
    text and contributes no rows and no flag entry.
    """
    accessions, error = extract_accessions_from_input(file, raw_text)
    if error:
        # NOTE(review): this path returns 4 values while the success path
        # returns 5 - confirm against the caller's output list.
        return [], "", "", f"Error: {error}"

    separator = "\n\n---\n\n"
    all_rows, all_summaries, all_flags = [], [], []
    for acc in accessions:
        try:
            rows, summary, label, explain = summarize_results(acc)
            all_rows.extend(rows)
            all_summaries.append(f"**{acc}**\n{summary}")
            all_flags.append(
                f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}"
            )
        except Exception as e:
            all_summaries.append(f"**{acc}**: Failed - {e}")

    summary_text = separator.join(all_summaries)
    flag_text = separator.join(all_flags)
    show = gr.update(visible=True)
    hide = gr.update(visible=False)
    return all_rows, summary_text, flag_text, show, hide