| | import copy |
| | from typing import Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from torch import nn |
| | from transformers import OwlViTConfig |
| | |
| |
|
class OwlViTBoxPredictionHead(nn.Module):
    """MLP that regresses 4 raw box values per image token.

    Four hidden ``Linear(width, width)`` + GELU layers are followed by a linear output layer, where ``width`` is
    the vision tower's hidden size. The output is deliberately left unactivated: ``box_predictor`` adds a
    per-patch box bias and applies a sigmoid, so the raw values must be free to take any sign and magnitude.
    """

    def __init__(self, config: OwlViTConfig):
        super().__init__()

        width = config.vision_config.hidden_size
        # Registration order kept as-is so state_dict keys and parameter initialisation are unchanged.
        self.dense0 = nn.Linear(width, width)
        self.dense1 = nn.Linear(width, width)
        self.dense2 = nn.Linear(width, width)
        self.dense3 = nn.Linear(width, width)
        self.gelu = nn.GELU()
        self.dense4 = nn.Linear(width, 4)

    def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
        """
        Args:
            image_features: Token features whose last dimension is ``config.vision_config.hidden_size``.

        Returns:
            Unbounded box values with the same leading dimensions as the input and a last dimension of 4.
        """
        output = image_features
        for hidden_layer in (self.dense0, self.dense1, self.dense2, self.dense3):
            output = self.gelu(hidden_layer(output))
        # No activation on the output layer: a trailing GELU would clamp every value to >= ~-0.17, so the
        # prediction could barely move below the prior encoded by the box bias.
        return self.dense4(output)
| |
|
| |
|
| |
|
class OwlViTClassPredictionHead(nn.Module):
    """Scores patch embeddings against text query embeddings (OWL-ViT style class head).

    ``dense0`` projects vision features into the text-embedding space. The resulting approximately-cosine
    similarities are shifted and scaled by per-patch values predicted from the same vision features.
    """

    def __init__(self, config: OwlViTConfig):
        super().__init__()

        text_width = config.text_config.hidden_size
        self.query_dim = config.vision_config.hidden_size

        # Registration order kept stable so state_dict keys and parameter initialisation are unchanged.
        self.dense0 = nn.Linear(self.query_dim, text_width)
        self.logit_shift = nn.Linear(self.query_dim, 1)
        self.logit_scale = nn.Linear(self.query_dim, 1)
        self.elu = nn.ELU()

    def forward(
        self,
        image_embeds: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor],
        query_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            image_embeds: Patch features whose last dimension is ``vision_config.hidden_size``.
            query_embeds: Query embeddings (last dimension ``text_config.hidden_size``), or ``None``.
            query_mask: Optional mask; query positions where it equals 0 get a logit of -1e6.

        Returns:
            ``(pred_logits, image_class_embeds)``. Without queries, ``pred_logits`` is an all-zero tensor of shape
            (batch, num_patches, vision hidden size) and ``image_class_embeds`` is the raw projection. Otherwise
            ``pred_logits`` has shape (..., num_patches, num_queries) and ``image_class_embeds`` is the normalised
            projection; when a mask is given, ``pred_logits`` is float32.
        """
        projected = self.dense0(image_embeds)

        if query_embeds is None:
            n_batch, n_patches = projected.shape[:2]
            placeholder = torch.zeros((n_batch, n_patches, self.query_dim), device=projected.device)
            return (placeholder, projected)

        projected = F.normalize(projected, dim=-1) + 1e-6
        queries = F.normalize(query_embeds, dim=-1) + 1e-6

        # Similarity between every patch (p) and every query (q).
        similarity = torch.einsum("...pd,...qd->...pq", projected, queries)

        # Per-patch affine calibration; ELU(x) + 1 is always positive.
        shift = self.logit_shift(image_embeds)
        scale = self.elu(self.logit_scale(image_embeds)) + 1
        logits = (similarity + shift) * scale

        if query_mask is None:
            return (logits, projected)

        # Add a patch axis so a batched mask broadcasts over patches.
        mask = torch.unsqueeze(query_mask, dim=-2) if query_mask.ndim > 1 else query_mask
        # Mask in float64, then cast the result to float32.
        logits = torch.where(mask == 0, -1e6, logits.to(torch.float64)).to(torch.float32)
        return (logits, projected)
| |
|
| |
|
class OwlViTPredictionHead(nn.Module):
    """Image-level classifier that scores selected part patches against class-part text embeddings.

    For every part, the patch(es) chosen by ``topk_idxs`` are pooled, projected into the text-embedding space by
    ``mlp_image`` and compared with the text queries. A class's score is the sum over parts of each part's
    similarity to that class's query for the same part.
    """

    def __init__(self, config: OwlViTConfig, num_classes: int, finetuned: bool):
        super().__init__()

        out_dim = config.text_config.hidden_size
        self.query_dim = config.vision_config.hidden_size
        # Stored for callers; not read inside this class.
        self.finetuned = finetuned
        self.num_classes = num_classes

        self.mlp_image = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=self.query_dim, out_features=self.query_dim),
            nn.GELU(),
            nn.Linear(in_features=self.query_dim, out_features=self.query_dim),
            nn.GELU(),
            nn.Linear(in_features=self.query_dim, out_features=out_dim),
            nn.GELU(),
        )

    def forward(self,
        image_embeds: torch.FloatTensor,
        query_embeds: torch.FloatTensor,
        topk_idxs: torch.LongTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Args:
            image_embeds: Patch features of shape (batch, num_patches, vision hidden size).
            query_embeds: Text query embeddings, reshapeable to (batch, num_classes * n_parts, text dim) and ordered
                class-major (all parts of class 0, then all parts of class 1, ...).
            topk_idxs: Integer patch indices of shape (batch, k, n_parts), e.g. from ``torch.topk(..., dim=1)``.
                The k (distinct) patches selected for a part are summed.

        Returns:
            ``(image_text_logits, final_pred_logits, pred_logits)``:
            every part-vs-query similarity, shape (batch * n_parts, num_classes * n_parts);
            per-class scores, shape (batch, num_classes);
            per-class, per-part scores, shape (batch, num_classes, n_parts).
        """
        # (batch, k, n_parts) -> (batch, n_parts, k)
        topk_idxs = torch.swapaxes(topk_idxs, 1, 2)
        batch_size, n_parts = topk_idxs.shape[0], topk_idxs.shape[1]

        # Gather the selected patches, giving (batch, n_parts, k, hidden), and sum over k. This equals the
        # one-hot selection summed over patches, but avoids a (batch, n_parts, num_patches, hidden) intermediate.
        # It also keeps NaN/inf in unselected patches from leaking in via 0 * inf.
        batch_index = torch.arange(batch_size, device=image_embeds.device).view(-1, 1, 1)
        image_embeds = image_embeds[batch_index, topk_idxs].sum(dim=2)

        # Project each pooled part embedding into the text-embedding space.
        image_embeds = self.mlp_image(image_embeds.reshape(-1, image_embeds.shape[-1])).view(batch_size, n_parts, -1)
        query_embeds = query_embeds.reshape(batch_size, -1, query_embeds.shape[-1])

        image_embeds = F.normalize(image_embeds, dim=-1) + 1e-6
        query_embeds = F.normalize(query_embeds, dim=-1) + 1e-6

        # Similarity of every part embedding (n) with every text query (i): (batch, n_parts, num_queries).
        image_text_logits = torch.einsum('bnd, bid -> bni', image_embeds, query_embeds)
        image_text_logits_reshaped = image_text_logits.reshape(-1, image_text_logits.shape[-1])

        # Split the queries into (class, part). For each class, keep part n's similarity to that class's part-n
        # query, i.e. the diagonal, giving (batch, num_classes, n_parts).
        pred_logits = image_text_logits.transpose(1, 2).reshape(batch_size, self.num_classes, n_parts, -1)
        pred_logits = torch.diagonal(pred_logits, dim1=-2, dim2=-1)

        # Per-class score: sum over parts.
        final_pred_logits = torch.sum(pred_logits, dim=-1)

        return (image_text_logits_reshaped, final_pred_logits, pred_logits)
| |
|
| |
|
class OwlViTForClassification(nn.Module):
    """Part-based image classifier built on a pretrained OWL-ViT detection model.

    The detector's ``owlvit`` backbone and ``layer_norm`` are deep-copied. Its box head is either copied or
    replaced by a fresh ``OwlViTBoxPredictionHead``, and its class head weights are loaded into an
    ``OwlViTClassPredictionHead``. In ``forward``, each part query is assigned its best-scoring patch, either by
    the class head or, when ``logits_from_teacher`` is set, by teacher logits from the targets. ``cls_head`` then
    scores the selected patches against the class-part text embeddings.

    NOTE(review): the block nesting in ``forward`` and ``loss_symmetric`` was reconstructed from which names each
    statement needs (source indentation was lost). Confirm against version control.
    """
    config_class = OwlViTConfig

    def __init__(self, owlvit_det_model, num_classes, weight_dict, device, freeze_box_heads=False, train_box_heads_only=False, network_type=None, logits_from_teacher=False, finetuned: bool = False, custom_box_head: bool = False):
        """
        Args:
            owlvit_det_model: Pretrained OWL-ViT detection model; its ``config``, ``owlvit``, ``layer_norm``,
                ``box_head`` and ``class_head`` are read here.
            num_classes: Number of image-level classes.
            weight_dict: Loss weights. Keys read by this class: "loss_bbox", "loss_ce", "loss_sym_box_label",
                "loss_xclip".
            device: Device for tensors created inside ``forward`` and ``loss_symmetric``.
            freeze_box_heads: If True, freeze ``box_head`` and ``class_head``.
            train_box_heads_only: If True, freeze ``owlvit``, ``layer_norm`` and ``cls_head``.
            network_type: If "classification", the x-clip loss is a cross-entropy on ``pred_logits``; otherwise
                ``loss_symmetric`` is used.
            logits_from_teacher: If True, part-to-patch assignments come from ``targets[i]["logits"]``.
            finetuned: Forwarded to ``OwlViTPredictionHead``.
            custom_box_head: If True, use a freshly initialised ``OwlViTBoxPredictionHead`` instead of a copy of the
                detector's box head.
        """
        super(OwlViTForClassification, self).__init__()

        self.config = owlvit_det_model.config
        self.num_classes = num_classes
        # Number of part queries; hard-coded and used for index and label shapes below.
        self.num_parts = 12
        self.device = device

        self.sigmoid = nn.Sigmoid()
        self.ce_loss = torch.nn.CrossEntropyLoss()

        # "classification" selects the cross-entropy x-clip loss in `forward`.
        self.network_type = network_type
        # If True, part-to-patch assignments are read from targets[i]["logits"] instead of the class head.
        self.logits_from_teacher = logits_from_teacher

        # Independent copies of the detector's backbone and final layer norm.
        self.owlvit = copy.deepcopy(owlvit_det_model.owlvit)
        self.layer_norm = copy.deepcopy(owlvit_det_model.layer_norm)

        # Image-level, part-based classification head.
        self.cls_head = OwlViTPredictionHead(self.config, self.num_classes, finetuned=finetuned)

        # Box regression head: a fresh custom head, or a copy of the detector's.
        if custom_box_head:
            self.box_head = OwlViTBoxPredictionHead(self.config)
        else:
            self.box_head = copy.deepcopy(owlvit_det_model.box_head)

        # Part (text-query) head, initialised from the detector's class-head weights.
        self.class_head = OwlViTClassPredictionHead(self.config)
        self.class_head.dense0.load_state_dict(owlvit_det_model.class_head.dense0.state_dict())
        self.class_head.logit_shift.load_state_dict(owlvit_det_model.class_head.logit_shift.state_dict())
        self.class_head.logit_scale.load_state_dict(owlvit_det_model.class_head.logit_scale.state_dict())

        # Criterion losses: cardinality is always computed; boxes/labels only when their weight is positive.
        self.weight_dict = weight_dict
        losses = ["cardinality"]
        losses += ["boxes"] if weight_dict["loss_bbox"] > 0 else []
        losses += ["labels"] if weight_dict["loss_ce"] > 0 else []

        self.criterion = DetrLoss(
            matcher=None,
            num_parts=self.num_parts,
            eos_coef=0.1,
            losses=losses,
        )

        self.freeze_parameters(freeze_box_heads, train_box_heads_only)
        # Only drops this local name; the caller's model object is unaffected.
        del owlvit_det_model

    def freeze_parameters(self, freeze_box_heads, train_box_heads_only):
        """Set ``requires_grad = False`` on submodules that should not be trained.

        The text tower and text projection are always frozen.

        Args:
            freeze_box_heads: Also freeze ``box_head`` and ``class_head``.
            train_box_heads_only: Also freeze the whole ``owlvit`` backbone, ``layer_norm`` and ``cls_head``.
        """
        for param in self.owlvit.text_model.parameters():
            param.requires_grad = False
        for param in self.owlvit.text_projection.parameters():
            param.requires_grad = False

        if freeze_box_heads:
            for param in self.box_head.parameters():
                param.requires_grad = False
            for param in self.class_head.parameters():
                param.requires_grad = False

        if train_box_heads_only:
            for param in self.owlvit.parameters():
                param.requires_grad = False
            for param in self.layer_norm.parameters():
                param.requires_grad = False
            for param in self.cls_head.parameters():
                param.requires_grad = False

    def update_num_classes(self, num_classes):
        """Change the number of classes on this model and on ``cls_head`` (no layers are resized)."""
        self.num_classes = num_classes
        self.cls_head.num_classes = num_classes

    def image_text_embedder(self,
        input_ids: torch.Tensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor]:
        """Run both OWL-ViT towers and build the patch feature map.

        Returns:
            ``(text_embeds, image_embeds, outputs)``. ``image_embeds`` has shape (batch, side, side, hidden) with
            ``side = int(sqrt(num_patches))``; ``outputs`` is the raw ``self.owlvit`` output.
        """
        outputs = self.owlvit(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Last hidden state of the vision tower, followed by its post layer norm.
        last_hidden_state = outputs.vision_model_output[0]
        image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)

        # Broadcast token 0 (the class token) to the shape of the remaining (patch) tokens.
        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

        # Multiply each patch embedding element-wise by the class-token embedding, then layer-norm.
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Reshape to a square grid; assumes the number of patches is a perfect square.
        new_size = (
            image_embeds.shape[0],
            int(np.sqrt(image_embeds.shape[1])),
            int(np.sqrt(image_embeds.shape[1])),
            image_embeds.shape[-1],
        )
        image_embeds = image_embeds.reshape(new_size)
        # NOTE(review): positional access into the model output, presumably the text embeddings. Confirm the
        # field order for the installed transformers version.
        text_embeds = outputs[-4]

        return (text_embeds, image_embeds, outputs)

    def image_embedder(
        self,
        pixel_values: torch.FloatTensor
    ) -> Tuple[torch.FloatTensor]:
        """Run only the vision tower and build the patch feature map.

        Returns:
            ``(image_embeds, vision_outputs)``, where ``image_embeds`` has shape (batch, side, side, hidden) with
            ``side = int(sqrt(num_patches))``.
        """
        vision_outputs = self.owlvit.vision_model(pixel_values=pixel_values, return_dict=True)

        # Last hidden state of the vision tower, followed by its post layer norm.
        last_hidden_state = vision_outputs[0]
        image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)

        # Broadcast token 0 (the class token) to the shape of the remaining (patch) tokens.
        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

        # Multiply each patch embedding element-wise by the class-token embedding, then layer-norm.
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Reshape to a square grid; assumes the number of patches is a perfect square.
        new_size = (
            image_embeds.shape[0],
            int(np.sqrt(image_embeds.shape[1])),
            int(np.sqrt(image_embeds.shape[1])),
            image_embeds.shape[-1],
        )
        image_embeds = image_embeds.reshape(new_size)

        return (image_embeds, vision_outputs)

    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
        """Return normalised (x, y) grid coordinates for every patch of ``feature_map``.

        With ``n = feature_map.shape[1]``, each axis takes the values ``1/n, 2/n, ..., 1``. The grid is flattened
        to shape (n * n, 2) and placed on ``feature_map``'s device.

        Raises:
            ValueError: if ``feature_map`` is not 4-D.
        """
        if not feature_map.ndim == 4:
            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")

        device = feature_map.device
        num_patches = feature_map.shape[1]

        box_coordinates = np.stack(np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1).astype(np.float32)
        box_coordinates /= np.array([num_patches, num_patches], np.float32)

        # Flatten (n, n, 2) -> (n * n, 2).
        box_coordinates = box_coordinates.reshape(box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2])
        box_coordinates = torch.from_numpy(box_coordinates).to(device)

        return box_coordinates

    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
        """Per-patch prior box logits, shape (n * n, 4).

        Columns 0-1 are the approximate inverse sigmoid of the patch coordinates. Columns 2-3 are the approximate
        inverse sigmoid of a box size of ``1 / feature_map.shape[-2]``.
        """
        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
        box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

        # log(x) - log(1 - x), with 1e-4 terms keeping both logs finite at 0 and 1.
        box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
        box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

        box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
        return box_bias

    def box_predictor(
        self,
        image_feats: torch.FloatTensor,
        feature_map: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Args:
            image_feats:
                Flattened patch features, shape (batch, num_patches, hidden).
            feature_map:
                The same features arranged as a (batch, side, side, hidden) grid.
        Returns:
            pred_boxes:
                Tensor of sigmoid-activated box values in (0, 1), one 4-vector per patch. The layout is presumably
                (cx, cy, w, h), matching the bias order of ``compute_box_bias``; confirm.
        """
        pred_boxes = self.box_head(image_feats)

        # Add the per-patch prior (broadcast over the batch) and squash to (0, 1).
        pred_boxes += self.compute_box_bias(feature_map)
        pred_boxes = self.sigmoid(pred_boxes)
        return pred_boxes

    def class_predictor(
        self,
        image_feats: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            image_feats:
                Flattened patch features, shape (batch, num_patches, hidden).
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
        Returns:
            ``(pred_logits, image_class_embeds)`` from ``self.class_head``.
        """
        (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)

        return (pred_logits, image_class_embeds)

    def _get_text_query_mask(self, text_inputs, text_embeds, batch_size: int):
        """Reshape flat query embeddings per image and build the query-validity mask.

        Returns:
            ``(query_mask, text_embeds)``. ``text_embeds`` is reshaped to (batch_size, max_text_queries, dim).
            ``query_mask`` is True where a query's first input id is > 0 (padding queries presumably start with
            id 0; confirm against the tokenizer setup).
        """
        input_ids = text_inputs["input_ids"]

        max_text_queries = input_ids.shape[0] // batch_size
        text_embeds = text_embeds.reshape(batch_size, max_text_queries, text_embeds.shape[-1])

        input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
        query_mask = input_ids[..., 0] > 0
        return query_mask, text_embeds

    def forward(self, image_inputs, text_inputs_parts, text_embeds, targets: dict = None):
        """Classify images from part evidence and, when targets are given, compute losses.

        Args:
            image_inputs: Either a precomputed feature map tensor (batch, side, side, hidden), or a mapping whose
                "pixel_values" entry is passed to ``image_embedder``.
            text_inputs_parts: Tokenizer outputs for the part queries (must include "input_ids"); used only when
                ``logits_from_teacher`` is False.
            text_embeds: Class-part text embeddings passed to ``cls_head``.
            targets: Per-image target dicts, iterated as a sequence despite the ``dict`` annotation. Keys read
                here: "logits" (teacher mode) and "targets_cls" (x-clip loss), plus whatever ``self.criterion``
                reads.

        Returns:
            ``(pred_logits, part_logits, loss_dict)``. ``pred_logits`` has shape (batch, num_classes) and
            ``part_logits`` has shape (batch, num_classes, n_parts), both from ``cls_head``. ``loss_dict`` holds
            the losses; in the non-teacher path without targets it holds "pred_boxes" (selected boxes, on CPU).
        """
        loss_dict = {}

        if not isinstance(image_inputs, torch.Tensor):
            feature_map, _ = self.image_embedder(pixel_values = image_inputs['pixel_values'])
        else:
            feature_map = image_inputs
        batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
        # Flatten the grid: (batch, side * side, hidden).
        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

        if self.logits_from_teacher:
            # Teacher assignment: the best patch (dim 1) for every part query.
            teacher_boxes_logits = torch.stack([target["logits"] for target in targets], dim=0).to(self.device)
            topk_scores, topk_idxs = torch.topk(teacher_boxes_logits, k=1, dim=1)

        else:
            text_embeds_parts = self.owlvit.get_text_features(**text_inputs_parts)

            query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)

            # Patch-vs-part-query logits: (batch, num_patches, num_queries).
            pred_logits_parts, class_embeds = self.class_predictor(image_feats, text_embeds_parts, query_mask)

            pred_boxes = self.box_predictor(image_feats, feature_map)

            # For each part query, select the single highest-scoring patch.
            scores = self.sigmoid(pred_logits_parts)
            topk_scores, topk_idxs = torch.topk(scores, k=1, dim=1)
            # One (selected patch indices, part indices 0..num_parts-1) pair per image for the criterion.
            mapping_indices = [(selected_indices, torch.tensor(list(range(self.num_parts))).to(self.device)) for selected_indices in topk_idxs.squeeze(1)]

            # Boxes of the selected patches, (batch, num_queries, 4), gathered on CPU.
            selected_idxs = torch.stack([item[0].cpu() for item in mapping_indices])
            loss_dict["pred_boxes"] = torch.gather(pred_boxes.cpu(), 1, selected_idxs.unsqueeze(-1).expand(*selected_idxs.shape, 4))

            if targets is not None:
                outputs_loss = {}
                outputs_loss["logits"] = pred_logits_parts
                outputs_loss["pred_boxes"] = pred_boxes

                # NOTE(review): rebinding loss_dict discards the "pred_boxes" entry stored above.
                loss_dict = self.criterion(outputs_loss, targets, mapping_indices)

                logits_per_image = torch.softmax(pred_logits_parts, dim=1)
                logits_per_text = torch.softmax(pred_logits_parts, dim=-1)

                # NOTE(review): teacher_boxes_logits is only assigned when logits_from_teacher is True, which never
                # reaches this branch, so this raises NameError whenever loss_sym_box_label > 0.
                if self.weight_dict["loss_sym_box_label"] > 0:
                    sym_loss_box_label = self.loss_symmetric(logits_per_image, logits_per_text, teacher_boxes_logits)
                    loss_dict["loss_sym_box_label"] = sym_loss_box_label

        # Image-level classification from the selected part patches.
        image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)

        # NOTE(review): `targets` is iterated unconditionally here; this fails if targets is None and
        # loss_xclip > 0.
        if self.weight_dict["loss_xclip"] > 0:
            targets_cls = torch.tensor([target["targets_cls"] for target in targets]).unsqueeze(1).to(self.device)
            if self.network_type == "classification":
                one_hot = torch.zeros_like(pred_logits).scatter(1, targets_cls, 1).to(self.device)
                cls_loss = self.ce_loss(pred_logits, one_hot)
                loss_dict["loss_xclip"] = cls_loss
            else:
                logits_per_image = torch.softmax(image_text_logits, dim=0)
                logits_per_text = torch.softmax(image_text_logits, dim=-1)
                sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
                loss_dict["loss_xclip"] = sym_loss

        return pred_logits, part_logits, loss_dict

    def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
        """Sum of two BCE-with-logits terms, one per score matrix, against a shared label matrix.

        If ``targets`` differs in shape from the scores, it is read as one class index per image
        (``targets.view(-1, 1)``). The label at (b * num_parts + i, c * num_parts + j) is then 1 when c equals
        image b's class and i == j, and 0 otherwise. When ``box_labels`` is given, each row is scaled by it.
        If the shapes match, each row of ``targets`` becomes a one-hot of its argmax.

        NOTE(review): ``forward`` passes softmax probabilities, but ``binary_cross_entropy_with_logits`` applies
        its own sigmoid; confirm this is intended. ``forward`` also passes the image-side scores as
        ``text_logits``; the two terms are summed, so the result does not depend on the order.

        Raises:
            AssertionError: if the two score tensors differ in shape.
        """
        assert text_logits.shape == image_logits.shape

        if image_logits.shape != targets.shape:
            batch_size = targets.shape[0]

            # Identity within every (num_parts x num_parts) block: part i only matches part i.
            default_box_labels = torch.kron(torch.ones(batch_size, self.num_classes), torch.eye(self.num_parts)).to(self.device)
            if box_labels is None:
                box_labels = default_box_labels.clone()
            else:
                box_labels = box_labels.view(-1, 1) * default_box_labels

            # Expand each image's one-hot class label into a (num_parts x num_parts) block of ones.
            target_one_hot = torch.zeros(batch_size, self.num_classes).to(self.device).scatter(1, targets.view(-1, 1), 1)
            target_one_hot = torch.kron(target_one_hot, torch.ones(self.num_parts, self.num_parts).to(self.device))

            matching_labels = target_one_hot * box_labels
        else:
            # Same-shape targets: one-hot of each row's argmax.
            values, indices = torch.max(targets, dim=1)
            matching_labels = torch.zeros_like(targets).scatter(1, indices.unsqueeze(1), 1)

        loss_i = F.binary_cross_entropy_with_logits(image_logits, matching_labels, reduction='mean')
        loss_t = F.binary_cross_entropy_with_logits(text_logits, matching_labels, reduction='mean')
        sym_loss = (loss_i + loss_t).mean()

        return sym_loss
| | |
class DetrLoss(nn.Module):
    """
    Part-assignment losses adapted from the DETR criterion (Hugging Face `DetrLoss`).

    Unlike DETR, no Hungarian matching happens here. The caller passes `indices` to `forward`: a list with one
    `(source_indices, target_indices)` pair of index tensors per image. `source_indices` selects entries of
    `outputs["logits"]` / `outputs["pred_boxes"]` along dim 1, and `target_indices` selects entries of that image's
    target tensors.

    Args:
        matcher:
            Stored as `self.matcher`; not used by this class.
        num_parts (`int`):
            Number of classes in the label loss. The selected logits are reshaped to this size in `loss_labels`.
        eos_coef (`float`):
            Cross-entropy weight for the last class (index `num_parts - 1`).
        losses (`List[str]`):
            Names of the losses to compute. See `get_loss` for the supported names.
    """

    def __init__(self, matcher, num_parts, eos_coef, losses):
        super().__init__()
        self.matcher = matcher
        self.num_parts = num_parts
        self.eos_coef = eos_coef
        self.losses = losses

        # Per-class cross-entropy weights: 1 everywhere except the last class, which gets `eos_coef`.
        # NOTE(review): in DETR that slot is an extra "no-object" class. Here the vector has exactly `num_parts`
        # entries, so this down-weights the last real part class. Confirm this is intended.
        empty_weight = torch.ones(self.num_parts)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """
        Weighted cross-entropy between the selected logits and the matched part labels.

        `outputs["logits"][batch_idx, source_idx]` is reshaped to (batch, num_selected, num_parts). It is compared
        with `targets[i]["class_labels"][target_indices]` for each image. `num_boxes` is unused.

        Raises:
            KeyError: if `outputs` has no "logits".
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        source_logits = source_logits[idx].view(len(indices), -1, self.num_parts)
        target_classes = torch.stack([t["class_labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)

        # cross_entropy wants the class dimension second: (batch, num_parts, num_selected).
        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
        return {"loss_ce": loss_ce}

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """
        Logging-only metric; gradients are disabled.

        It is the mean absolute difference, per image, between two counts: the positions of `outputs["logits"]`
        along dim 1 whose argmax is not the last class, and the number of target labels.
        """
        logits = outputs["logits"]
        device = logits.device
        target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
        card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
        card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
        return {"cardinality_error": card_err}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """
        L1 and generalized-IoU losses between the selected predicted boxes and the matched target boxes.

        Target dicts must contain "boxes" of shape (nb_target_boxes, 4) in (center_x, center_y, w, h) format,
        normalised by the image size. Both losses are summed and divided by `num_boxes`.

        Raises:
            KeyError: if `outputs` has no "pred_boxes".
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")

        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        # NOTE(review): `generalized_box_iou` and `center_to_corners_format` are not imported at the top of this
        # module. They must be available at module level (e.g. the DETR box utilities) before "boxes" is enabled.
        loss_giou = 1 - torch.diag(generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)))
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """
        Focal and dice losses between the selected predicted masks and the matched target masks.

        Target dicts must contain "masks" of shape (nb_target_boxes, h, w).

        Raises:
            KeyError: if `outputs` has no "pred_masks".
        """
        if "pred_masks" not in outputs:
            raise KeyError("No predicted masks found in outputs")

        source_idx = self._get_source_permutation_idx(indices)
        target_idx = self._get_target_permutation_idx(indices)
        source_masks = outputs["pred_masks"]
        source_masks = source_masks[source_idx]
        masks = [t["masks"] for t in targets]

        # NOTE(review): `nested_tensor_from_tensor_list`, `sigmoid_focal_loss` and `dice_loss` are not imported at
        # the top of this module; they must be available at module level before "masks" is enabled.
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(source_masks)
        target_masks = target_masks[target_idx]

        # Upsample the predictions to the target mask size.
        source_masks = nn.functional.interpolate(
            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        source_masks = source_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(source_masks.shape)
        return {
            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
        }

    def _get_source_permutation_idx(self, indices):
        """Flatten per-image source indices into `(batch_idx, source_idx)` tensors for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
        source_idx = torch.cat([source for (source, _) in indices])
        return batch_idx, source_idx

    def _get_target_permutation_idx(self, indices):
        """Flatten per-image target indices into `(batch_idx, target_idx)` tensors for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
        target_idx = torch.cat([target for (_, target) in indices])
        return batch_idx, target_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        """Dispatch to the loss named `loss`: "labels", "cardinality", "boxes" or "masks".

        Raises:
            ValueError: for any other name.
        """
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        if loss not in loss_map:
            raise ValueError(f"Loss {loss} not supported")
        return loss_map[loss](outputs, targets, indices, num_boxes)

    def forward(self, outputs, targets, indices):
        """
        Compute every loss listed in `self.losses`.

        Args:
            outputs (`dict`):
                Model outputs. Keys read by the individual losses include "logits" and "pred_boxes". An optional
                "auxiliary_outputs" entry holds a list of dicts with the same keys.
            targets (`List[dict]`):
                One dict per image. Every dict must contain "class_labels", plus "boxes"/"masks" when those losses
                are enabled.
            indices (`List[Tuple[Tensor, Tensor]]`):
                Per-image `(source_indices, target_indices)` pairs, used in place of a matcher.

        Returns:
            `dict` mapping loss names to values. Losses on auxiliary outputs get an `_{i}` suffix.
        """
        # Normaliser for the box/mask losses: total number of target labels, at least 1.
        num_boxes = float(max(sum(len(t["class_labels"]) for t in targets), 1))

        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # Auxiliary outputs reuse the same indices; mask losses are skipped for them.
        if "auxiliary_outputs" in outputs:
            for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
                for loss in self.losses:
                    if loss == "masks":
                        continue
                    l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
                    losses.update({f"{k}_{i}": v for k, v in l_dict.items()})

        return losses