# MicroscopyMatching / inference_track.py
# VisionLanguageGroup's picture
# cleanup
# 4ce5a27
# inference_track.py
import torch
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from tracking_one import TrackingModule
from models.tra_post_model.tracking import graph_to_ctc
# Module-level state shared through `global` by load_model():
# DEVICE starts on CPU and MODEL stays None until a model has been loaded.
DEVICE = torch.device("cpu")
MODEL = None
def load_model(use_box=False):
    """
    load tracking model from Hugging Face Hub

    Builds a TrackingModule, downloads the checkpoint
    "microscopy_matching_tra.pth" from the "phoebe777777/111" repository,
    loads the weights strictly, switches the model to eval mode and moves it
    to CUDA when available (CPU otherwise).

    The module-level MODEL and DEVICE are only updated after every step has
    succeeded, so a failed load never leaves a model with uninitialised
    weights behind in MODEL.

    Args:
        use_box: use bounding box as input (default: False)
    Returns:
        model: loaded tracking model (None if loading failed)
        device: device the model was placed on (CPU device if loading failed)

    Errors are not propagated: they are printed together with a traceback and
    reported through the (None, cpu) return value.
    """
    global MODEL, DEVICE

    try:
        print("๐Ÿ”„ Loading tracking model...")

        # Initialize the model in a local variable; the globals are only
        # rebound once the whole load has succeeded.
        model = TrackingModule(use_box=use_box)

        # Load checkpoint from Hugging Face Hub
        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_tra.pth",
            token=None,
            force_download=False
        )
        print(f"โœ… Checkpoint downloaded: {ckpt_path}")

        # Load weights on CPU first; the model is moved to its device below.
        # NOTE(review): torch.load unpickles the downloaded file. Consider
        # passing weights_only=True if the installed PyTorch supports it and
        # the checkpoint is a plain state dict -- confirm before changing.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        # Move model to device
        if torch.cuda.is_available():
            device = torch.device("cuda")
            placement_message = "โœ… Model moved to CUDA"
        else:
            device = torch.device("cpu")
            placement_message = "โœ… Model on CPU"
        model.move_to_device(device)
        print(placement_message)

        # Everything succeeded: publish the ready-to-use model.
        MODEL, DEVICE = model, device
        print("โœ… Tracking model loaded successfully")
        return MODEL, DEVICE

    except Exception as e:
        print(f"โŒ Error loading tracking model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
def _failed_tracking_result(message):
    """Build the result dict that run() returns when tracking did not succeed."""
    return {
        'track_graph': None,
        'masks': None,
        'masks_tracked': None,
        'output_dir': None,
        'num_tracks': 0,
        'error': message,
    }


@torch.no_grad()
def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
    """
    run tracking inference on video frames

    Calls model.track() in "greedy" mode on video_dir, converts the result
    with graph_to_ctc() (which is given output_dir as its outdir) and returns
    everything in one dict. Success and failure results share the same keys,
    except that 'error' is only present on failure.

    Args:
        model: loaded tracking model (None yields a 'Model not loaded' result)
        video_dir: directory of video frame sequence (contains consecutive image files)
        box: bounding box (optional), forwarded to model.track as `boxes`
        device: device; accepted for interface compatibility, not used here
        output_dir: output directory, created if it does not exist
    Returns:
        result_dict: {
            'track_graph': TrackGraph object containing tracking results,
            'masks': masks returned by model.track,
            'masks_tracked': masks returned by graph_to_ctc,
            'output_dir': output directory path,
            'num_tracks': number of tracked trajectories,
            'error': error message (failure only)
        }

    Exceptions raised during tracking are caught, printed with a traceback
    and reported through the 'error' entry instead of being propagated.
    """
    if model is None:
        return _failed_tracking_result('Model not loaded')

    try:
        print(f"๐Ÿ”„ Running tracking inference on {video_dir}")

        # Run tracking
        track_graph, masks = model.track(
            file_dir=video_dir,
            boxes=box,
            mode="greedy",  # Optional: "greedy", "greedy_nodiv", "ilp"
            dataname="tracking_result"
        )

        # Create the output directory; exist_ok avoids the check-then-create
        # race of a separate os.path.exists() test.
        os.makedirs(output_dir, exist_ok=True)

        # Convert tracking results to CTC format and save
        print("๐Ÿ”„ Converting to CTC format...")
        ctc_tracks, masks_tracked = graph_to_ctc(
            track_graph,
            masks,
            outdir=output_dir,
        )
        print(f"โœ… CTC results saved to {output_dir}")

        # NOTE(review): assumes masks_tracked is an array-like of integer
        # track labels with 0 as background, so each distinct non-zero label
        # is one trajectory -- confirm against graph_to_ctc.
        track_labels = np.unique(np.asarray(masks_tracked))
        num_tracks = int(np.count_nonzero(track_labels))

        print("โœ… Tracking completed")

        return {
            'track_graph': track_graph,
            'masks': masks,
            'masks_tracked': masks_tracked,
            'output_dir': output_dir,
            'num_tracks': num_tracks,
        }

    except Exception as e:
        print(f"โŒ Tracking inference error: {e}")
        import traceback
        traceback.print_exc()
        return _failed_tracking_result(str(e))
def visualize_tracking_result(masks_tracked, output_path):
    """
    visualize tracking results

    Renders every frame of masks_tracked as a colour image (one 'tab20'
    colour per distinct label, label 0 left black) and writes the frames to
    an 'mp4v' video at 5 frames per second.

    Args:
        masks_tracked: masks with tracking results (T, H, W)
        output_path: output video file path
    Returns:
        output_path: output video file path, or None if the input is empty or
            not 3-D, the video writer cannot be opened, or any other error
            occurs (the error is printed with a traceback, not raised)
    """
    try:
        import cv2
        import matplotlib.pyplot as plt

        labels = np.asarray(masks_tracked)
        if labels.ndim != 3:
            raise ValueError(
                f"masks_tracked must have shape (T, H, W), got {labels.shape}"
            )
        if labels.size == 0:
            raise ValueError("masks_tracked is empty; nothing to visualize")
        T, H, W = labels.shape

        # create a color map for unique track IDs
        unique_ids = np.unique(labels)
        num_colors = len(unique_ids)
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed
        # in Matplotlib 3.9; both resample 'tab20' to num_colors entries.
        cmap = plt.get_cmap('tab20', num_colors)

        # One RGB row per entry of unique_ids (same index -> same colour as a
        # per-id cmap(i) lookup); the float -> uint8 cast truncates, exactly
        # like assigning the float colour into a uint8 image would.
        palette = (cmap(np.arange(num_colors))[:, :3] * 255).astype(np.uint8)
        palette[unique_ids == 0] = 0  # label 0 is left black

        # create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(str(output_path), fourcc, 5.0, (W, H))
        try:
            if not out.isOpened():
                raise OSError(f"could not open video writer for {output_path}")

            for frame in labels:
                # unique_ids is sorted and contains every value of the frame,
                # so searchsorted gives each pixel's row in the palette.
                colored_frame = palette[np.searchsorted(unique_ids, frame)]

                # convert to BGR (OpenCV format)
                colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
                out.write(colored_frame_bgr)
        finally:
            # Always finalize/close the file, even if a frame failed.
            out.release()

        print(f"โœ… Visualization saved to {output_path}")
        return output_path

    except Exception as e:
        print(f"โŒ Visualization error: {e}")
        import traceback
        traceback.print_exc()
        return None