from .model import FastSAM
import numpy as np
from PIL import Image
from typing import Optional


class FastSAMDecoder:
    """Thin wrapper around a FastSAM model that separates the heavy
    encoding pass from the cheap prompt-based decoding passes."""

    def __init__(
        self,
        model: FastSAM,
        device: str = 'cpu',
        conf: float = 0.4,
        iou: float = 0.9,
        imgsz: int = 1024,
        retina_masks: bool = True,
    ):
        self.model = model
        self.device = device
        self.retina_masks = retina_masks
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou
        self.image = None
        self.image_embedding = None
    def run_encoder(self, image):
        """Run the full FastSAM model once and cache the input image.

        Despite the name, the returned "embedding" is the complete set of
        everything-mode results (masks and boxes) converted to NumPy.
        """
        if isinstance(image, str):
            image = np.array(Image.open(image))
        self.image = image
        image_embedding = self.model(
            self.image,
            device=self.device,
            retina_masks=self.retina_masks,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
        )
        return image_embedding[0].numpy()
    def run_decoder(
        self,
        image_embedding,
        point_prompt: Optional[np.ndarray] = None,
        point_label: Optional[np.ndarray] = None,
        box_prompt: Optional[np.ndarray] = None,
        text_prompt: Optional[str] = None,
    ) -> Optional[np.ndarray]:
        """Select masks from the cached results.

        Prompts are tried in priority order (point, then box, then text);
        None is returned when no prompt is supplied.
        """
        self.image_embedding = image_embedding
        if point_prompt is not None:
            return self.point_prompt(points=point_prompt, pointlabel=point_label)
        elif box_prompt is not None:
            return self.box_prompt(bbox=box_prompt)
        elif text_prompt is not None:
            # Note: text_prompt is not defined on this class and must be
            # supplied elsewhere (e.g. by a subclass).
            return self.text_prompt(text=text_prompt)
        else:
            return None
    def box_prompt(self, bbox):
        """Return the everything-mode mask with the highest IoU against an
        xyxy box given in original-image coordinates."""
        assert (bbox[2] != 0 and bbox[3] != 0)
        masks = self.image_embedding.masks.data
        target_height = self.image.shape[0]
        target_width = self.image.shape[1]
        h = masks.shape[1]
        w = masks.shape[2]
        # Rescale the box into mask resolution if the masks were not
        # produced at full image resolution.
        if h != target_height or w != target_width:
            bbox = [
                int(bbox[0] * w / target_width),
                int(bbox[1] * h / target_height),
                int(bbox[2] * w / target_width),
                int(bbox[3] * h / target_height),
            ]
        # Clamp the box to the mask bounds.
        bbox[0] = max(round(bbox[0]), 0)
        bbox[1] = max(round(bbox[1]), 0)
        bbox[2] = min(round(bbox[2]), w)
        bbox[3] = min(round(bbox[3]), h)
        bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
        # IoU between the box and every candidate mask; the intersection is
        # approximated by each mask's area inside the box.
        masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
        orig_masks_area = np.sum(masks, axis=(1, 2))
        union = bbox_area + orig_masks_area - masks_area
        IoUs = masks_area / union
        max_iou_index = np.argmax(IoUs)
        # Masks are already NumPy here (run_encoder calls .numpy()), so no
        # .cpu().numpy() round-trip is needed.
        return np.array([masks[max_iou_index]])
    def point_prompt(self, points, pointlabel):  # numpy
        """Compose a single mask from the everything-mode masks: positive
        points (label 1) add the mask they fall inside, negative points
        (label 0) carve it back out."""
        # Pass the Results object directly, as box_prompt does.
        masks = self._format_results(self.image_embedding, 0)
        target_height = self.image.shape[0]
        target_width = self.image.shape[1]
        h = masks[0]['segmentation'].shape[0]
        w = masks[0]['segmentation'].shape[1]
        # Rescale points into mask resolution if needed.
        if h != target_height or w != target_width:
            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
        onemask = np.zeros((h, w))
        # Visit large masks first so smaller masks can refine the result.
        masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        for annotation in masks:
            mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
            for i, point in enumerate(points):
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                    onemask[mask] = 1
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
                    onemask[mask] = 0
        onemask = onemask >= 1
        return np.array([onemask])
    def _format_results(self, result, min_area=0):
        """Convert a Results object into a list of per-mask annotation
        dicts, dropping masks smaller than `min_area` pixels."""
        annotations = []
        n = len(result.masks.data)
        for i in range(n):
            mask = result.masks.data[i] == 1.0
            if np.sum(mask) < min_area:
                continue
            annotation = {
                'id': i,
                'segmentation': mask,
                'bbox': result.boxes.data[i],
                'score': result.boxes.conf[i],
                'area': mask.sum(),
            }
            annotations.append(annotation)
        return annotations
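

# --- Usage sketch (illustrative addition, not from the original file) ---
# A separate script showing the encode-once / decode-many flow. It assumes
# the package is importable as `fastsam` and that "FastSAM-s.pt" and
# "image.jpg" exist; both names are placeholders.
#
#     import numpy as np
#     from fastsam import FastSAM, FastSAMDecoder
#
#     model = FastSAM('FastSAM-s.pt')
#     decoder = FastSAMDecoder(model, device='cpu')
#     embedding = decoder.run_encoder('image.jpg')   # heavy pass, run once
#
#     point_mask = decoder.run_decoder(
#         embedding,
#         point_prompt=np.array([[220, 180]]),       # (x, y) pixel coords
#         point_label=np.array([1]),                 # 1 = foreground
#     )
#     box_mask = decoder.run_decoder(
#         embedding,
#         box_prompt=np.array([50, 60, 300, 320]),   # xyxy pixel coords
#     )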