# mtDNALocation / mtdna_backend.py
import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
import mtdna_ui_app
from mtdna_classifier import classify_sample_location
from iterate3 import data_preprocess, model, pipeline
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
import hashlib
import threading
# @lru_cache(maxsize=3600)
# def classify_sample_location_cached(accession):
# return classify_sample_location(accession)
@lru_cache(maxsize=3600)
def pipeline_classify_sample_location_cached(accession):
return pipeline.pipeline_with_gemini([accession])
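# A hedged note on caching: the lru_cache above only lives for the lifetime of the
# Python process, so repeated lookups of the same accession in one session skip the
# Gemini pipeline, while check_known_output() further below provides the on-disk
# cache that survives restarts. Illustrative call (not executed here):
#
#   outputs = pipeline_classify_sample_location_cached("KU131308")
#   # `outputs` is a dict keyed by accession; see the sample structure quoted
#   # inside summarize_results() below.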
# Count and suggest final location
# def compute_final_suggested_location(rows):
# candidates = [
# row.get("Predicted Location", "").strip()
# for row in rows
# if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
# ] + [
# row.get("Inferred Region", "").strip()
# for row in rows
# if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
# ]
# if not candidates:
# return Counter(), ("Unknown", 0)
# # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
# tokens = []
# for item in candidates:
# # Split by comma, whitespace, and newlines
# parts = re.split(r'[\s,]+', item)
# tokens.extend(parts)
# # Step 2: Clean and normalize tokens
# tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
# # Step 3: Count
# counts = Counter(tokens)
# # Step 4: Get most common
# top_location, count = counts.most_common(1)[0]
# return counts, (top_location, count)
# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
if not answer1.strip() or not answer2.strip():
return "⚠️ Please answer both questions before submitting."
try:
        # ✅ Step: Load credentials from Hugging Face secret
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
# Connect to Google Sheet
client = gspread.authorize(creds)
sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
# Append feedback
sheet.append_row([accession, answer1, answer2, contact])
return "βœ… Feedback submitted. Thank you!"
except Exception as e:
return f"❌ Error submitting feedback: {e}"
# helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
print(f"RAW TEXT RECEIVED: {raw_text}")
accessions = []
seen = set()
if file:
try:
if file.name.endswith(".csv"):
df = pd.read_csv(file)
elif file.name.endswith(".xlsx"):
df = pd.read_excel(file)
else:
return [], "Unsupported file format. Please upload CSV or Excel."
for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
if acc not in seen:
accessions.append(acc)
seen.add(acc)
except Exception as e:
return [], f"Failed to read file: {e}"
if raw_text:
text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
for acc in text_ids:
if acc not in seen:
accessions.append(acc)
seen.add(acc)
return list(accessions), None
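# Illustrative behaviour (accession IDs below are placeholders): pasted text is split
# on newlines, commas, semicolons, and tabs, and duplicates are dropped while the
# original order is preserved.
#
#   accs, err = extract_accessions_from_input(raw_text="KU131308, MT576652\nKU131308")
#   # -> accs == ["KU131308", "MT576652"], err is None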
# ✅ Backend helper to filter out already-processed accessions when resuming a batch: `get_incomplete_accessions()`
def get_incomplete_accessions(file_path):
df = pd.read_excel(file_path)
incomplete_accessions = []
for _, row in df.iterrows():
sample_id = str(row.get("Sample ID", "")).strip()
# Skip if no sample ID
if not sample_id:
continue
# Drop the Sample ID and check if the rest is empty
other_cols = row.drop(labels=["Sample ID"], errors="ignore")
if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
# Extract the accession number from the sample ID using regex
match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
if match:
incomplete_accessions.append(match.group(0))
    print(f"Found {len(incomplete_accessions)} incomplete accessions to resume.")
return incomplete_accessions
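# Example (hedged): given a partially written batch_output_live.xlsx in which the row
# labelled "KU131308(Isolate: BRU18)" has every column blank except "Sample ID", the
# regex above recovers the bare accession so the batch can be resumed:
#
#   get_incomplete_accessions("batch_output_live.xlsx")  # -> ["KU131308"]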
def summarize_results(accession, KNOWN_OUTPUT_PATH = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/known_samples.xlsx"):
# try cache first
cached = check_known_output(accession)
if cached:
print(f"βœ… Using cached result for {accession}")
return [[
cached["Sample ID"],
cached["Predicted Country"],
cached["Country Explanation"],
cached["Predicted Sample Type"],
cached["Sample Type Explanation"],
cached["Sources"],
cached["Time cost"]
]]
# only run when nothing in the cache
try:
outputs = pipeline_classify_sample_location_cached(accession)
# outputs = {'KU131308': {'isolate':'BRU18',
# 'country': {'brunei': ['ncbi',
# 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
# 'sample_type': {'modern':
# ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
# 'query_cost': 9.754999999999999e-05,
# 'time_cost': '24.776 seconds',
# 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
# 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
# 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
    except Exception as e:
        print(f"❌ Pipeline failed for {accession}: {e}")
        return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
if accession not in outputs:
return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
row_score = []
rows = []
save_rows = []
for key in outputs:
pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
for section, results in outputs[key].items():
if section == "country" or section =="sample_type":
pred_output = "\n".join(list(results.keys()))
output_explanation = ""
for result, content in results.items():
if len(result) == 0: result = "unknown"
if len(content) == 0: output_explanation = "unknown"
else:
output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
if section == "country":
pred_country, country_explanation = pred_output, output_explanation
elif section == "sample_type":
pred_sample, sample_explanation = pred_output, output_explanation
if outputs[key]["isolate"].lower()!="unknown":
label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
else: label = key
if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
row = {
"Sample ID": label,
"Predicted Country": pred_country,
"Country Explanation": country_explanation,
"Predicted Sample Type":pred_sample,
"Sample Type Explanation":sample_explanation,
"Sources": "\n".join(outputs[key]["source"]),
"Time cost": outputs[key]["time_cost"]
}
#row_score.append(row)
rows.append(list(row.values()))
save_row = {
"Sample ID": label,
"Predicted Country": pred_country,
"Country Explanation": country_explanation,
"Predicted Sample Type":pred_sample,
"Sample Type Explanation":sample_explanation,
"Sources": "\n".join(outputs[key]["source"]),
"Query_cost": outputs[key]["query_cost"],
"Time cost": outputs[key]["time_cost"]
}
#row_score.append(row)
save_rows.append(list(save_row.values()))
# #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
# summary_lines = [f"### 🧭 Location Summary:\n"]
# summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
# summary_lines.append(f"\n**Final Suggested Location:** πŸ—ΊοΈ **{final_location}** (mentioned {count} times)")
# summary = "\n".join(summary_lines)
# save the new running sample to known excel file
try:
df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
if os.path.exists(KNOWN_OUTPUT_PATH):
df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
else:
df_combined = df_new
df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
except Exception as e:
print(f"⚠️ Failed to save known output: {e}")
return rows#, summary, labelAncient_Modern, explain_label
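# Shape of the value returned above (hedged, reconstructed from the cached KU131308
# example quoted inside this function): one display row per accession, with
# multi-value fields joined by newlines.
#
#   [["KU131308(Isolate: BRU18)",
#     "brunei",
#     "Method: ncbi\nMethod: rag_llm-The text mentions ...",
#     "modern",
#     "Method: rag_llm-The text mentions ...",
#     "https://doi.org/10.1007/s00439-015-1620-z\n...",
#     "24.776 seconds"]]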
# save the batch input in excel file
# def save_to_excel(all_rows, summary_text, flag_text, filename):
# with pd.ExcelWriter(filename) as writer:
# # Save table
# df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
# df.to_excel(writer, sheet_name="Detailed Results", index=False)
# try:
# df_old = pd.read_excel(filename)
# except:
# df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
# df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
# # if os.path.exists(filename):
# # df_old = pd.read_excel(filename)
# # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
# # else:
# # df_combined = df_new
# df_combined.to_excel(filename, index=False)
# # # Save summary
# # summary_df = pd.DataFrame({"Summary": [summary_text]})
# # summary_df.to_excel(writer, sheet_name="Summary", index=False)
# # # Save flag
# # flag_df = pd.DataFrame({"Flag": [flag_text]})
# # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
# def save_to_excel(all_rows, summary_text, flag_text, filename):
# df_new = pd.DataFrame(all_rows, columns=[
# "Sample ID", "Predicted Country", "Country Explanation",
# "Predicted Sample Type", "Sample Type Explanation",
# "Sources", "Time cost"
# ])
# try:
# if os.path.exists(filename):
# df_old = pd.read_excel(filename)
# else:
# df_old = pd.DataFrame(columns=df_new.columns)
# except Exception as e:
# print(f"⚠️ Warning reading old Excel file: {e}")
# df_old = pd.DataFrame(columns=df_new.columns)
# #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
# df_old.set_index("Sample ID", inplace=True)
# df_new.set_index("Sample ID", inplace=True)
# df_old.update(df_new) # <-- update matching rows in df_old with df_new content
# df_combined = df_old.reset_index()
# try:
# df_combined.to_excel(filename, index=False)
# except Exception as e:
# print(f"❌ Failed to write Excel file {filename}: {e}")
def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
df_new = pd.DataFrame(all_rows, columns=[
"Sample ID", "Predicted Country", "Country Explanation",
"Predicted Sample Type", "Sample Type Explanation",
"Sources", "Time cost"
])
if is_resume and os.path.exists(filename):
try:
df_old = pd.read_excel(filename)
except Exception as e:
print(f"⚠️ Warning reading old Excel file: {e}")
df_old = pd.DataFrame(columns=df_new.columns)
# Set index and update existing rows
df_old.set_index("Sample ID", inplace=True)
df_new.set_index("Sample ID", inplace=True)
df_old.update(df_new)
df_combined = df_old.reset_index()
else:
# If not resuming or file doesn't exist, just use new rows
df_combined = df_new
try:
df_combined.to_excel(filename, index=False)
except Exception as e:
print(f"❌ Failed to write Excel file {filename}: {e}")
# save the batch input in JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
output_dict = {
"Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
# "Summary_Text": summary_text,
# "Ancient_Modern_Flag": flag_text
}
# If all_rows is a DataFrame, convert it
if isinstance(all_rows, pd.DataFrame):
output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
with open(filename, "w") as external_file:
json.dump(output_dict, external_file, indent=2)
# save the batch input in Text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
    if isinstance(all_rows, pd.DataFrame):
        detailed_results = all_rows.to_dict(orient="records")
    else:
        # all_rows may already be a plain list of row dicts
        detailed_results = all_rows
    output = ""
    if detailed_results:
        output += ",".join(list(detailed_results[0].keys())) + "\n\n"
for r in detailed_results:
output += ",".join([str(v) for v in r.values()]) + "\n\n"
with open(filename, "w") as f:
f.write("=== Detailed Results ===\n")
f.write(output + "\n")
# f.write("\n=== Summary ===\n")
# f.write(summary_text + "\n")
# f.write("\n=== Ancient/Modern Flag ===\n")
# f.write(flag_text + "\n")
def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
tmp_dir = tempfile.mkdtemp()
#html_table = all_rows.value # assuming this is stored somewhere
# Parse back to DataFrame
#all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
all_rows = pd.read_html(StringIO(all_rows))[0]
print(all_rows)
if output_type == "Excel":
file_path = f"{tmp_dir}/batch_output.xlsx"
save_to_excel(all_rows, summary_text, flag_text, file_path)
elif output_type == "JSON":
file_path = f"{tmp_dir}/batch_output.json"
save_to_json(all_rows, summary_text, flag_text, file_path)
print("Done with JSON")
elif output_type == "TXT":
file_path = f"{tmp_dir}/batch_output.txt"
save_to_txt(all_rows, summary_text, flag_text, file_path)
else:
return gr.update(visible=False) # invalid option
return gr.update(value=file_path, visible=True)
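# A sketch of how this might be hooked up in the UI (hedged; results_html,
# output_type_dropdown, and download_file are assumed component names): the first
# argument is the HTML string of the rendered results table, which pd.read_html
# parses back into a DataFrame before export.
#
#   download_btn.click(
#       fn=save_batch_output,
#       inputs=[results_html, output_type_dropdown],
#       outputs=download_file,   # e.g. a hidden gr.File revealed by gr.update(...)
#   )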
# save cost by checking the known outputs
def check_known_output(accession, KNOWN_OUTPUT_PATH = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/known_samples.xlsx"):
if not os.path.exists(KNOWN_OUTPUT_PATH):
return None
try:
df = pd.read_excel(KNOWN_OUTPUT_PATH)
match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
if match:
accession = match.group(0)
matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
if not matched.empty:
return matched.iloc[0].to_dict() # Return the cached row
except Exception as e:
print(f"⚠️ Failed to load known samples: {e}")
return None
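# Illustrative lookups (hedged): the match is substring-based and case-insensitive,
# so both the bare accession and the decorated label used in the saved sheet resolve
# to the same cached row.
#
#   check_known_output("KU131308")                  # matches "KU131308(Isolate: BRU18)"
#   check_known_output("KU131308(Isolate: BRU18)")  # regex strips it back to "KU131308"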
USER_USAGE_TRACK_FILE = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/user_usage_log.json"
def hash_user_id(user_input):
return hashlib.sha256(user_input.encode()).hexdigest()
# ✅ Load and save usage count
# def load_user_usage():
# if os.path.exists(USER_USAGE_TRACK_FILE):
# with open(USER_USAGE_TRACK_FILE, "r") as f:
# return json.load(f)
# return {}
def load_user_usage():
if not os.path.exists(USER_USAGE_TRACK_FILE):
return {}
try:
with open(USER_USAGE_TRACK_FILE, "r") as f:
content = f.read().strip()
if not content:
return {} # file is empty
return json.loads(content)
except (json.JSONDecodeError, ValueError):
print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
return {} # fallback to empty dict
def save_user_usage(usage):
with open(USER_USAGE_TRACK_FILE, "w") as f:
json.dump(usage, f, indent=2)
# def increment_usage(user_id, num_samples=1):
# usage = load_user_usage()
# if user_id not in usage:
# usage[user_id] = 0
# usage[user_id] += num_samples
# save_user_usage(usage)
# return usage[user_id]
def increment_usage(email: str, count: int):
usage = load_user_usage()
email_key = email.strip().lower()
usage[email_key] = usage.get(email_key, 0) + count
save_user_usage(usage)
return usage[email_key]
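# Usage-tracking sketch (hedged): summarize_batch() below passes the SHA-256 hash of
# the email (via hash_user_id) rather than the raw address, and the running totals
# are persisted in user_usage_log.json between sessions.
#
#   total = increment_usage(hash_user_id("researcher@example.org"), 3)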
# run the batch
def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
stop_flag=None, output_file_path=None,
limited_acc=50, yield_callback=None):
if user_email:
limited_acc += 10
accessions, error = extract_accessions_from_input(file, raw_text)
if error:
#return [], "", "", f"Error: {error}"
return [], f"Error: {error}", 0, "", ""
if resume_file:
accessions = get_incomplete_accessions(resume_file)
tmp_dir = tempfile.mkdtemp()
if not output_file_path:
if resume_file:
output_file_path = os.path.join(tmp_dir, resume_file)
else:
output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
all_rows = []
# all_summaries = []
# all_flags = []
progress_lines = []
warning = ""
if len(accessions) > limited_acc:
accessions = accessions[:limited_acc]
warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
for i, acc in enumerate(accessions):
if stop_flag and stop_flag.value:
line = f"πŸ›‘ Stopped at {acc} ({i+1}/{len(accessions)})"
progress_lines.append(line)
if yield_callback:
yield_callback(line)
print("πŸ›‘ User requested stop.")
break
print(f"[{i+1}/{len(accessions)}] Processing {acc}")
try:
# rows, summary, label, explain = summarize_results(acc)
rows = summarize_results(acc)
all_rows.extend(rows)
# all_summaries.append(f"**{acc}**\n{summary}")
            # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
#save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
line = f"βœ… Processed {acc} ({i+1}/{len(accessions)})"
progress_lines.append(line)
if yield_callback:
yield_callback(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
except Exception as e:
print(f"❌ Failed to process {acc}: {e}")
continue
#all_summaries.append(f"**{acc}**: Failed - {e}")
            #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
limited_acc -= 1
"""for row in all_rows:
source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
if source_column.startswith("http"): # Check if the source is a URL
# Wrap it with HTML anchor tags to make it clickable
row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
if not warning:
warning = f"You only have {limited_acc} left"
if user_email.strip():
user_hash = hash_user_id(user_email)
total_queries = increment_usage(user_hash, len(all_rows))
else:
total_queries = 0
yield_callback("βœ… Finished!")
# summary_text = "\n\n---\n\n".join(all_summaries)
# flag_text = "\n\n---\n\n".join(all_flags)
#return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
#return all_rows, gr.update(visible=True), gr.update(visible=False)
return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
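# A minimal sketch of how the UI layer might drive this batch runner from a worker
# thread (hedged; StopFlag is an illustrative stand-in for whatever object the app
# shares with the UI, it only needs a mutable .value attribute):
#
#   class StopFlag:
#       def __init__(self):
#           self.value = False
#
#   stop = StopFlag()
#   worker = threading.Thread(
#       target=summarize_batch,
#       kwargs={"raw_text": "KU131308", "stop_flag": stop, "yield_callback": print},
#   )
#   worker.start()
#   # stop.value = True   # request an early stop; the loop breaks at the next accession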