import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
import mtdna_ui_app
from mtdna_classifier import classify_sample_location
from iterate3 import data_preprocess, model, pipeline
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
import hashlib
import threading

# @lru_cache(maxsize=3600)
# def classify_sample_location_cached(accession):
#     return classify_sample_location(accession)

@lru_cache(maxsize=3600)
def pipeline_classify_sample_location_cached(accession):
    return pipeline.pipeline_with_gemini([accession])

# Count and suggest final location
# def compute_final_suggested_location(rows):
#     candidates = [
#         row.get("Predicted Location", "").strip()
#         for row in rows
#         if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ] + [
#         row.get("Inferred Region", "").strip()
#         for row in rows
#         if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ]
#     if not candidates:
#         return Counter(), ("Unknown", 0)
#     # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
#     tokens = []
#     for item in candidates:
#         # Split by comma, whitespace, and newlines
#         parts = re.split(r'[\s,]+', item)
#         tokens.extend(parts)
#     # Step 2: Clean and normalize tokens
#     tokens = [word.strip() for word in tokens if word.strip().isalpha()]  # Keep only alphabetic tokens
#     # Step 3: Count
#     counts = Counter(tokens)
#     # Step 4: Get most common
#     top_location, count = counts.most_common(1)[0]
#     return counts, (top_location, count)

# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
    if not answer1.strip() or not answer2.strip():
        return "⚠️ Please answer both questions before submitting."
    try:
        # ✅ Step: Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        # Connect to Google Sheet
        client = gspread.authorize(creds)
        sheet = client.open("feedback_mtdna").sheet1  # make sure sheet name matches
        # Append feedback
        sheet.append_row([accession, answer1, answer2, contact])
        return "✅ Feedback submitted. Thank you!"
    except Exception as e:
        return f"❌ Error submitting feedback: {e}"

# Helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
    print(f"RAW TEXT RECEIVED: {raw_text}")
    accessions = []
    seen = set()
    if file:
        try:
            if file.name.endswith(".csv"):
                df = pd.read_csv(file)
            elif file.name.endswith(".xlsx"):
                df = pd.read_excel(file)
            else:
                return [], "Unsupported file format. Please upload CSV or Excel."
            for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
                if acc not in seen:
                    accessions.append(acc)
                    seen.add(acc)
        except Exception as e:
            return [], f"Failed to read file: {e}"
    if raw_text:
        text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
        for acc in text_ids:
            if acc not in seen:
                accessions.append(acc)
                seen.add(acc)
    return list(accessions), None

# ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
def get_incomplete_accessions(file_path):
    df = pd.read_excel(file_path)
    incomplete_accessions = []
    for _, row in df.iterrows():
        sample_id = str(row.get("Sample ID", "")).strip()
        # Skip if no sample ID
        if not sample_id:
            continue
        # Drop the Sample ID and check if the rest is empty
        other_cols = row.drop(labels=["Sample ID"], errors="ignore")
        if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
            # Extract the accession number from the sample ID using regex
            match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
            if match:
                incomplete_accessions.append(match.group(0))
    print(len(incomplete_accessions))
    return incomplete_accessions
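
# Illustrative sketch (not called by the app): how the accession pattern used in
# get_incomplete_accessions() and check_known_output() pulls an accession such as
# "KU131308" out of a labelled Sample ID. The sample string below is an example only.
def _demo_accession_regex(sample_id="KU131308(Isolate: BRU18)"):
    match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
    return match.group(0) if match else None  # -> "KU131308"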

def summarize_results(accession, KNOWN_OUTPUT_PATH="/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/known_samples.xlsx"):
    # Try the cache first
    cached = check_known_output(accession)
    if cached:
        print(f"✅ Using cached result for {accession}")
        return [[
            cached["Sample ID"],
            cached["Predicted Country"],
            cached["Country Explanation"],
            cached["Predicted Sample Type"],
            cached["Sample Type Explanation"],
            cached["Sources"],
            cached["Time cost"]
        ]]
    # Only run the pipeline when nothing is in the cache
    try:
        outputs = pipeline_classify_sample_location_cached(accession)
        # Example output:
        # outputs = {'KU131308': {'isolate': 'BRU18',
        #            'country': {'brunei': ['ncbi',
        #                        'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #            'sample_type': {'modern':
        #                        ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #            'query_cost': 9.754999999999999e-05,
        #            'time_cost': '24.776 seconds',
        #            'source': ['https://doi.org/10.1007/s00439-015-1620-z',
        #                       'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
        #                       'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
    except Exception as e:
        return []  #, f"Error: {e}", f"Error: {e}", f"Error: {e}"
    if accession not in outputs:
        return []  #, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."

    row_score = []
    rows = []
    save_rows = []
    for key in outputs:
        pred_country, pred_sample, country_explanation, sample_explanation = "unknown", "unknown", "unknown", "unknown"
        for section, results in outputs[key].items():
            if section == "country" or section == "sample_type":
                pred_output = "\n".join(list(results.keys()))
                output_explanation = ""
                for result, content in results.items():
                    if len(result) == 0:
                        result = "unknown"
                    if len(content) == 0:
                        output_explanation = "unknown"
                    else:
                        output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
                if section == "country":
                    pred_country, country_explanation = pred_output, output_explanation
                elif section == "sample_type":
                    pred_sample, sample_explanation = pred_output, output_explanation
        if outputs[key]["isolate"].lower() != "unknown":
            label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
        else:
            label = key
        if len(outputs[key]["source"]) == 0:
            outputs[key]["source"] = ["No Links"]
        row = {
            "Sample ID": label,
            "Predicted Country": pred_country,
            "Country Explanation": country_explanation,
            "Predicted Sample Type": pred_sample,
            "Sample Type Explanation": sample_explanation,
            "Sources": "\n".join(outputs[key]["source"]),
            "Time cost": outputs[key]["time_cost"]
        }
        #row_score.append(row)
        rows.append(list(row.values()))
        save_row = {
            "Sample ID": label,
            "Predicted Country": pred_country,
            "Country Explanation": country_explanation,
            "Predicted Sample Type": pred_sample,
            "Sample Type Explanation": sample_explanation,
            "Sources": "\n".join(outputs[key]["source"]),
            "Query_cost": outputs[key]["query_cost"],
            "Time cost": outputs[key]["time_cost"]
        }
        #row_score.append(row)
        save_rows.append(list(save_row.values()))

    # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
    # summary_lines = [f"### 🧭 Location Summary:\n"]
    # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
    # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
    # summary = "\n".join(summary_lines)

    # Save the newly processed sample to the known-samples Excel file
    try:
        df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost", "Time cost"])
        if os.path.exists(KNOWN_OUTPUT_PATH):
            df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
            df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
        else:
            df_combined = df_new
        df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
    except Exception as e:
        print(f"⚠️ Failed to save known output: {e}")
    return rows  #, summary, labelAncient_Modern, explain_label

# Save the batch output to an Excel file
# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     with pd.ExcelWriter(filename) as writer:
#         # Save table
#         df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#         df.to_excel(writer, sheet_name="Detailed Results", index=False)
#         try:
#             df_old = pd.read_excel(filename)
#         except:
#             df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#         df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#         # if os.path.exists(filename):
#         #     df_old = pd.read_excel(filename)
#         #     df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#         # else:
#         #     df_combined = df_new
#         df_combined.to_excel(filename, index=False)
#         # # Save summary
#         # summary_df = pd.DataFrame({"Summary": [summary_text]})
#         # summary_df.to_excel(writer, sheet_name="Summary", index=False)
#         # # Save flag
#         # flag_df = pd.DataFrame({"Flag": [flag_text]})
#         # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)

# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     df_new = pd.DataFrame(all_rows, columns=[
#         "Sample ID", "Predicted Country", "Country Explanation",
#         "Predicted Sample Type", "Sample Type Explanation",
#         "Sources", "Time cost"
#     ])
#     try:
#         if os.path.exists(filename):
#             df_old = pd.read_excel(filename)
#         else:
#             df_old = pd.DataFrame(columns=df_new.columns)
#     except Exception as e:
#         print(f"⚠️ Warning reading old Excel file: {e}")
#         df_old = pd.DataFrame(columns=df_new.columns)
#     #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
#     df_old.set_index("Sample ID", inplace=True)
#     df_new.set_index("Sample ID", inplace=True)
#     df_old.update(df_new)  # <-- update matching rows in df_old with df_new content
#     df_combined = df_old.reset_index()
#     try:
#         df_combined.to_excel(filename, index=False)
#     except Exception as e:
#         print(f"❌ Failed to write Excel file {filename}: {e}")

def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
    df_new = pd.DataFrame(all_rows, columns=[
        "Sample ID", "Predicted Country", "Country Explanation",
        "Predicted Sample Type", "Sample Type Explanation",
        "Sources", "Time cost"
    ])
    if is_resume and os.path.exists(filename):
        try:
            df_old = pd.read_excel(filename)
        except Exception as e:
            print(f"⚠️ Warning reading old Excel file: {e}")
            df_old = pd.DataFrame(columns=df_new.columns)
        # Set index and update existing rows
        df_old.set_index("Sample ID", inplace=True)
        df_new.set_index("Sample ID", inplace=True)
        df_old.update(df_new)
        df_combined = df_old.reset_index()
    else:
        # If not resuming or the file doesn't exist, just use the new rows
        df_combined = df_new
    try:
        df_combined.to_excel(filename, index=False)
    except Exception as e:
        print(f"❌ Failed to write Excel file {filename}: {e}")

# Save the batch output to a JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
    output_dict = {
        "Detailed_Results": all_rows  # <-- make sure this is a plain list, not a DataFrame
        # "Summary_Text": summary_text,
        # "Ancient_Modern_Flag": flag_text
    }
    # If all_rows is a DataFrame, convert it
    if isinstance(all_rows, pd.DataFrame):
        output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
    with open(filename, "w") as external_file:
        json.dump(output_dict, external_file, indent=2)

# Save the batch output to a text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
    if isinstance(all_rows, pd.DataFrame):
        detailed_results = all_rows.to_dict(orient="records")
    else:
        detailed_results = list(all_rows)  # assumed to already be a list of row dicts
    output = ""
    output += ",".join(list(detailed_results[0].keys())) + "\n\n"
    for r in detailed_results:
        output += ",".join([str(v) for v in r.values()]) + "\n\n"
    with open(filename, "w") as f:
        f.write("=== Detailed Results ===\n")
        f.write(output + "\n")
        # f.write("\n=== Summary ===\n")
        # f.write(summary_text + "\n")
        # f.write("\n=== Ancient/Modern Flag ===\n")
        # f.write(flag_text + "\n")

def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
    tmp_dir = tempfile.mkdtemp()
    #html_table = all_rows.value  # assuming this is stored somewhere
    # Parse the HTML table back to a DataFrame
    #all_rows = pd.read_html(all_rows)[0]  # [0] because read_html returns a list
    all_rows = pd.read_html(StringIO(all_rows))[0]
    print(all_rows)
    if output_type == "Excel":
        file_path = f"{tmp_dir}/batch_output.xlsx"
        save_to_excel(all_rows, summary_text, flag_text, file_path)
    elif output_type == "JSON":
        file_path = f"{tmp_dir}/batch_output.json"
        save_to_json(all_rows, summary_text, flag_text, file_path)
        print("Done with JSON")
    elif output_type == "TXT":
        file_path = f"{tmp_dir}/batch_output.txt"
        save_to_txt(all_rows, summary_text, flag_text, file_path)
    else:
        return gr.update(visible=False)  # invalid option
    return gr.update(value=file_path, visible=True)
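
# Minimal sketch of the HTML round-trip that save_batch_output() relies on: the
# results table arrives as an HTML string (e.g., rendered by a Gradio component)
# and pandas.read_html() turns it back into a DataFrame before export. The column
# names and values below are placeholders, not real pipeline output.
def _demo_html_roundtrip():
    df = pd.DataFrame(
        [["KU131308(Isolate: BRU18)", "brunei"]],
        columns=["Sample ID", "Predicted Country"],
    )
    html_table = df.to_html(index=False)
    return pd.read_html(StringIO(html_table))[0]  # mirrors the original DataFrame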

# save cost by checking the known outputs
def check_known_output(accession, KNOWN_OUTPUT_PATH="/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/known_samples.xlsx"):
    if not os.path.exists(KNOWN_OUTPUT_PATH):
        return None
    try:
        df = pd.read_excel(KNOWN_OUTPUT_PATH)
        match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
        if match:
            accession = match.group(0)
        matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
        if not matched.empty:
            return matched.iloc[0].to_dict()  # Return the cached row
    except Exception as e:
        print(f"⚠️ Failed to load known samples: {e}")
    return None

USER_USAGE_TRACK_FILE = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/user_usage_log.json"

def hash_user_id(user_input):
    return hashlib.sha256(user_input.encode()).hexdigest()

# ✅ Load and save usage count
# def load_user_usage():
#     if os.path.exists(USER_USAGE_TRACK_FILE):
#         with open(USER_USAGE_TRACK_FILE, "r") as f:
#             return json.load(f)
#     return {}

def load_user_usage():
    if not os.path.exists(USER_USAGE_TRACK_FILE):
        return {}
    try:
        with open(USER_USAGE_TRACK_FILE, "r") as f:
            content = f.read().strip()
            if not content:
                return {}  # file is empty
            return json.loads(content)
    except (json.JSONDecodeError, ValueError):
        print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
        return {}  # fall back to an empty dict

def save_user_usage(usage):
    with open(USER_USAGE_TRACK_FILE, "w") as f:
        json.dump(usage, f, indent=2)

# def increment_usage(user_id, num_samples=1):
#     usage = load_user_usage()
#     if user_id not in usage:
#         usage[user_id] = 0
#     usage[user_id] += num_samples
#     save_user_usage(usage)
#     return usage[user_id]

def increment_usage(email: str, count: int):
    usage = load_user_usage()
    email_key = email.strip().lower()
    usage[email_key] = usage.get(email_key, 0) + count
    save_user_usage(usage)
    return usage[email_key]

# Run the batch
def summarize_batch(file=None, raw_text="", resume_file=None, user_email="", stop_flag=None,
                    output_file_path=None, limited_acc=50, yield_callback=None):
    if user_email:
        limited_acc += 10
    accessions, error = extract_accessions_from_input(file, raw_text)
    if error:
        #return [], "", "", f"Error: {error}"
        return [], f"Error: {error}", 0, "", ""
    if resume_file:
        accessions = get_incomplete_accessions(resume_file)
    tmp_dir = tempfile.mkdtemp()
    if not output_file_path:
        if resume_file:
            output_file_path = os.path.join(tmp_dir, resume_file)
        else:
            output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")

    all_rows = []
    # all_summaries = []
    # all_flags = []
    progress_lines = []
    warning = ""
    if len(accessions) > limited_acc:
        accessions = accessions[:limited_acc]
        warning = f"You submitted more than {limited_acc} accessions; only the first {limited_acc} will be processed."

    for i, acc in enumerate(accessions):
        if stop_flag and stop_flag.value:
            line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(line)
            print("🛑 User requested stop.")
            break
        print(f"[{i+1}/{len(accessions)}] Processing {acc}")
        try:
            # rows, summary, label, explain = summarize_results(acc)
            rows = summarize_results(acc)
            all_rows.extend(rows)
            # all_summaries.append(f"**{acc}**\n{summary}")
            # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
            #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
            save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
            line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(line)
        except Exception as e:
            print(f"❌ Failed to process {acc}: {e}")
            continue
            #all_summaries.append(f"**{acc}**: Failed - {e}")
            #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
        limited_acc -= 1

    """for row in all_rows:
        source_column = row[2]  # Assuming the "Source" is in the 3rd column (index 2)
        if source_column.startswith("http"):  # Check if the source is a URL
            # Wrap it with HTML anchor tags to make it clickable
            row[2] = f'{source_column}'"""

    if not warning:
        warning = f"You only have {limited_acc} accession queries left."
    if user_email.strip():
        user_hash = hash_user_id(user_email)
        total_queries = increment_usage(user_hash, len(all_rows))
    else:
        total_queries = 0
    if yield_callback:
        yield_callback("✅ Finished!")
    # summary_text = "\n\n---\n\n".join(all_summaries)
    # flag_text = "\n\n---\n\n".join(all_flags)
    #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
    #return all_rows, gr.update(visible=True), gr.update(visible=False)
    return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
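
# Minimal usage sketch (not wired into the Gradio app): run a tiny batch from the
# command line. It assumes the iterate3 pipeline, its API credentials, and the
# Google Drive cache paths above are all available; the accession and email are
# placeholders, and stop_flag only needs a mutable .value attribute.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_stop_flag = SimpleNamespace(value=False)
    rows, out_path, total_queries, progress, warning = summarize_batch(
        raw_text="KU131308",       # example accession from the comments above
        user_email="",             # leave empty to skip usage tracking
        stop_flag=demo_stop_flag,
        yield_callback=print,      # stream progress lines to stdout
    )
    print(f"Saved {len(rows)} row(s) to {out_path}; {warning}")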