Spaces:
Running
Running
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Any, Dict, List | |
def load_img_to_array(img_p):
    """Read an image into a NumPy array.

    RGBA images are converted to RGB (the alpha channel is dropped); images in
    any other mode are returned exactly as Pillow decodes them.

    Args:
        img_p: path (or anything else ``PIL.Image.open`` accepts).

    Returns:
        ``np.ndarray`` holding the decoded pixels.
    """
    # Use the image as a context manager so the underlying file handle is
    # closed as soon as the pixels are copied out, rather than whenever the
    # Image object is garbage-collected (Pillow can keep the file open for
    # lazily loaded or multi-frame images).
    with Image.open(img_p) as img:
        if img.mode == "RGBA":
            img = img.convert("RGB")
        # np.array copies the pixel data, so the result stays valid after the
        # file is closed on leaving the with-block.
        return np.array(img)
def save_array_to_img(img_arr, img_p):
    """Write a NumPy image array to ``img_p`` using Pillow.

    The array is cast with ``astype(np.uint8)`` before saving; values are not
    rescaled, so callers should pass data already in the 0-255 range.

    Args:
        img_arr: NumPy array of pixel values.
        img_p: destination passed to ``PIL.Image.Image.save``.
    """
    pixels_u8 = img_arr.astype(np.uint8)
    pil_image = Image.fromarray(pixels_u8)
    pil_image.save(img_p)
def dilate_mask(mask, dilate_factor=15):
    """Grow the foreground of ``mask`` with a square all-ones kernel.

    Args:
        mask: NumPy mask; cast to ``uint8`` before dilation.
        dilate_factor: side length of the square structuring element.

    Returns:
        The ``uint8`` mask after a single ``cv2.dilate`` pass.
    """
    mask_u8 = mask.astype(np.uint8)
    square_kernel = np.ones((dilate_factor,) * 2, dtype=np.uint8)
    return cv2.dilate(mask_u8, square_kernel, iterations=1)
def erode_mask(mask, dilate_factor=15):
    """Shrink the foreground of ``mask`` with a square all-ones kernel.

    Args:
        mask: NumPy mask; cast to ``uint8`` before erosion.
        dilate_factor: side length of the square structuring element. Despite
            its name, here it sizes the *erosion* kernel.

    Returns:
        The ``uint8`` mask after a single ``cv2.erode`` pass.
    """
    mask_u8 = mask.astype(np.uint8)
    side = dilate_factor
    structuring_element = np.ones((side, side), dtype=np.uint8)
    eroded = cv2.erode(mask_u8, structuring_element, iterations=1)
    return eroded
def show_mask(ax, mask: np.ndarray, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent colored overlay.

    Args:
        ax: object with an ``imshow`` method (e.g. a Matplotlib Axes).
        mask: array whose last two dimensions are (height, width) and whose
            total size equals height * width. Both 0/1 (or bool) and 0-255
            masks are accepted. It is cast to ``uint8`` first, so fractional
            values are truncated.
        random_color: if True, use a random RGB color from ``np.random``;
            otherwise use the fixed color (30, 144, 255) / 255. The alpha
            channel is 0.6 in both cases.
    """
    mask = mask.astype(np.uint8)
    # Any value above 1 means a 0-255 encoding; scale it to 0-1 so the RGBA
    # overlay stays inside imshow's float range. Checking "> 1" rather than
    # "== 255" also covers masks whose peak is below 255. The size check keeps
    # max() from raising on an empty mask.
    if mask.size and mask.max() > 1:
        mask = mask / 255
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) color to get an RGBA
    # image; background pixels (mask 0) end up fully transparent.
    mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_img)
def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
    """Scatter point prompts on ``ax`` as star markers, colored by label.

    Args:
        ax: object with a ``scatter`` method (e.g. a Matplotlib Axes).
        coords: sequence of ``[x, y]`` points (a single flat ``[x, y]`` is
            also accepted).
        labels: one label per point. Label 0 is drawn red and label 1 green;
            points with any other label are not drawn.
        size: marker size, forwarded to ``scatter`` as ``s``.
    """
    coords = np.asarray(coords)
    # A 1-D input is either empty (shape (0,)) or a single flat point. Make
    # it (N, 2) so the column slicing below works instead of raising
    # IndexError. (N, k) inputs are left unchanged.
    if coords.ndim == 1:
        coords = coords.reshape(-1, 2)
    labels = np.asarray(labels).reshape(-1)
    color_table = {0: 'red', 1: 'green'}
    for label_value, color in color_table.items():
        points = coords[labels == label_value]
        ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
                   s=size, edgecolor='white', linewidth=1.25)
def get_clicked_point(img_path):
    """Let the user pick one point on an image in an OpenCV window.

    Each left-click records a point and marks it with a red (BGR (0, 0, 255))
    dot; only the most recent click is kept and shown. A right-click ends the
    interaction and closes the window.

    Args:
        img_path: path passed to ``cv2.imread``.

    Returns:
        ``[x, y]`` of the last left-click as reported by the OpenCV mouse
        callback, or an empty list if the user right-clicked without ever
        left-clicking.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns None (missing or
            unreadable file).
    """
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None instead of raising;
    # without this check the error only surfaces later inside cv2.imshow.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    cv2.namedWindow("image")
    cv2.imshow("image", img)
    last_point = []
    keep_looping = True

    def mouse_callback(event, x, y, flags, param):
        nonlocal last_point, keep_looping
        if event == cv2.EVENT_LBUTTONDOWN:
            last_point = [x, y]
            # Draw on a fresh copy of the untouched image so the previous
            # marker vanishes without painting over the pixels under it.
            canvas = img.copy()
            cv2.circle(canvas, tuple(last_point), 5, (0, 0, 255), -1)
            cv2.imshow("image", canvas)
        elif event == cv2.EVENT_RBUTTONDOWN:
            keep_looping = False

    cv2.setMouseCallback("image", mouse_callback)
    # waitKey pumps the GUI event loop, which is what dispatches the mouse
    # callback; poll until a right-click clears keep_looping.
    while keep_looping:
        cv2.waitKey(1)
    cv2.destroyAllWindows()
    return last_point