| |
|
|
| |
|
|
| |
| |
|
|
|
|
| import numpy.random as random |
| import matplotlib.pyplot as plt |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Variable |
| from math import exp |
|
|
class SiLogLoss(nn.Module):
    """Scale-invariant logarithmic loss.

    With ``d = log(target + eps) - log(pred + eps)`` taken element-wise, the
    loss is ``sqrt(mean(d ** 2) - lambd * mean(d) ** 2 + eps)``, where the
    means run over every element of the inputs.

    Args:
        lambd: weight of the squared-mean term (0 gives a plain log-RMSE).
        eps: small constant added inside both logarithms and the square root.
    """

    def __init__(self, lambd=0.5, eps=1e-6):
        super().__init__()
        self.lambd = lambd
        self.eps = eps

    def forward(self, pred, target):
        eps = self.eps
        # Work in float32 regardless of the incoming dtype.
        log_gap = torch.log(target.float() + eps) - torch.log(pred.float() + eps)
        mean_of_squares = log_gap.pow(2).mean()
        square_of_mean = log_gap.mean().pow(2)
        return torch.sqrt(mean_of_squares - self.lambd * square_of_mean + eps)
|
|
class IntegrityPriorLoss(nn.Module):
    """Loss coupling a soft mask prediction with a depth map.

    ``forward`` computes:

    * ``py = gt * mask + (1 - gt) * (1 - mask)``, the agreement between
      ``mask`` and ``gt``, and its log-loss ``-log(py + epsilon)``;
    * a variance term: on false-positive-like pixels (``(1 - py) * mask``) it
      weights the squared deviation of the depth from the ``gt``-weighted mean
      depth; on false-negative-like pixels (``(1 - py) * gt``) it weights
      ``1 - deviation``. Both are scaled by ``py`` and the log-loss;
    * a gradient term: the absolute Sobel x/y responses of the depth map,
      weighted by the same log-loss.

    The result is ``mean(variance term) + mean(gradient term)``.

    The Sobel filters are ``Conv2d(1, 1, ...)``, so ``depth_map`` must have a
    single channel. ``mask`` and ``gt`` must broadcast against it —
    presumably ``(B, 1, H, W)`` with values in ``[0, 1]``; TODO confirm with
    callers.

    Args:
        epsilon: stabilizer inside the logarithm and floor for the
            foreground-pixel count.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()

        self.epsilon = epsilon
        # Not read by ``forward``; kept so existing attribute accesses still work.
        self.max_variance = 0.05
        self.max_grad = 0.05

        # Fixed 3x3 Sobel operators (no bias), frozen below.
        self.sobel_x = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.sobel_y = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)

        sobel_kernel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
        sobel_kernel_y = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)

        self.sobel_x.weight.data = sobel_kernel_x
        self.sobel_y.weight.data = sobel_kernel_y

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, mask, depth_map, gt):
        py = gt * mask + (1 - gt) * (1 - mask)
        false_pos = (1 - py) * mask
        false_neg = (1 - py) * gt
        log_p = -torch.log(py + self.epsilon)

        # Foreground mean depth, pooled over the whole batch (sum without dims).
        # The denominator is floored so an all-background ``gt`` yields 0
        # instead of 0 / 0 = NaN, which would otherwise make the entire loss NaN.
        fg_depth = (depth_map * gt).sum() / gt.sum().clamp_min(self.epsilon)
        diff = (depth_map - fg_depth) ** 2
        fp_diff = diff * false_pos
        fn_diff = (1 - diff) * false_neg
        var_weight = (fp_diff + fn_diff) * py
        variance_loss = torch.mean(log_p * var_weight)

        grad_x = torch.abs(self.sobel_x(depth_map))
        grad_y = torch.abs(self.sobel_y(depth_map))
        grad_loss = torch.mean(grad_x * log_p + grad_y * log_p)

        return variance_loss + grad_loss
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    The kernel is centred on index ``window_size // 2``, uses standard
    deviation ``sigma``, and its entries sum to one (float32 tensor).
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = [exp(-(idx - center) ** 2 / denom) for idx in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
|
|
def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian window for grouped convolution.

    The 2-D kernel is the outer product of ``gaussian(window_size, sigma)``
    with itself, replicated for each channel.

    Args:
        window_size: side length of the square window.
        channel: number of channels (one copy of the kernel per channel).
        sigma: Gaussian standard deviation; defaults to the historical 1.5.

    Returns:
        A contiguous float tensor of shape ``(channel, 1, window_size, window_size)``.
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # ``torch.autograd.Variable`` has been a no-op wrapper since PyTorch 0.4;
    # the plain tensor is returned directly.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
|
| def _ssim(img1, img2, window, window_size, channel, size_average=True): |
| mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel) |
| mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel) |
|
|
| mu1_sq = mu1.pow(2) |
| mu2_sq = mu2.pow(2) |
| mu1_mu2 = mu1*mu2 |
|
|
| sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq |
| sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq |
| sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2 |
|
|
| C1 = 0.01**2 |
| C2 = 0.03**2 |
|
|
| ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) |
|
|
| if size_average: |
| return ssim_map.mean() |
| else: |
| return ssim_map.mean(1).mean(1).mean(1) |
|
|
class SSIMLoss(torch.nn.Module):
    """Structural-similarity loss ``1 - (1 + SSIM) / 2``.

    This lies in ``[0, 1]`` whenever SSIM lies in ``[-1, 1]``. The Gaussian
    window is cached on the module. It is rebuilt when the input channel
    count changes, and moved/cast when the input's device or dtype differs.

    Args:
        window_size: side length of the Gaussian window.
        size_average: forwarded to ``_ssim``. If True the loss is a scalar;
            otherwise it has one value per sample.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        # Unpacking enforces a 4-D (N, C, H, W) input.
        (_, channel, _, _) = img1.size()
        window = self.window
        if channel != self.channel:
            window = create_window(self.window_size, channel)
        # Compare device as well as dtype. The legacy ``Tensor.type()`` string
        # ignores the CUDA device index, so a cached window could be reused on
        # the wrong GPU.
        if window.device != img1.device or window.dtype != img1.dtype:
            window = window.to(device=img1.device, dtype=img1.dtype)
        self.window = window
        self.channel = channel
        ssim = _ssim(img1, img2, window, self.window_size, channel, self.size_average)
        return 1 - (1 + ssim) / 2
|
|
def circular_highPassFiltering(img, ratio):
    """Apply a soft circular high-pass filter in the centred 2-D Fourier domain.

    A Gaussian centred on the zero frequency, with ``sigma = H * ratio / 4``,
    is min-max normalized per map and passed through the steep sigmoid
    ``sigmoid(100 * (0.5 - g))``. Frequencies where the normalized Gaussian
    exceeds 0.5 (near the centre) are suppressed; the rest pass.

    Args:
        img: ``(B, C, H, W)`` real tensor.
        ratio: tensor that, after ``ratio[..., None, None]``, broadcasts
            against ``(B, 1, H, W)``. Presumably ``(B, 1)`` for one ratio per
            sample — TODO confirm with callers.

    Returns:
        ``(mask, ifimg)``: the frequency-domain mask, and the magnitude of the
        filtered image min-max normalized to ``[0, 1]`` per ``(H, W)`` map.
        A constant map comes back as zeros instead of NaN.
    """
    device = img.device
    batch_size, _, height, width = img.shape
    sigma = (height * (ratio[..., None, None])) / 4
    center_h = height // 2
    center_w = width // 2
    # Integer offsets from the spectrum centre. 'ij' is the historical default;
    # spelling it out avoids the deprecation warning and keeps (row, col) order.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(-center_h, height - center_h),
        torch.arange(-center_w, width - center_w),
        indexing='ij',
    )
    grid_y = grid_y[None, None, ...].repeat(batch_size, 1, 1, 1).to(device)
    grid_x = grid_x[None, None, ...].repeat(batch_size, 1, 1, 1).to(device)

    gaussian_values = (1 / (2 * torch.pi * sigma ** 2)) * torch.exp(
        -(grid_x ** 2 + grid_y ** 2) / (2 * sigma ** 2))
    decreasing_matrix = _minmax_per_map(gaussian_values)
    mask = ((0.5 - decreasing_matrix) * 100).sigmoid()

    fft_shift = torch.fft.fftshift(torch.fft.fft2(img), dim=(2, 3))
    fft_shift = torch.mul(fft_shift, mask)
    idft_shift = torch.fft.ifftshift(fft_shift, dim=(2, 3))
    ifimg = torch.abs(torch.fft.ifft2(idft_shift))
    return mask, _minmax_per_map(ifimg)


def _minmax_per_map(t):
    """Min-max normalize each trailing ``(H, W)`` map of ``t`` to ``[0, 1]``.

    The range is floored at the dtype's smallest positive normal value, so a
    constant map yields zeros instead of ``0 / 0 = NaN``. Maps whose spread is
    a normal float are unaffected.
    """
    t_min = t.flatten(-2).min(dim=-1)[0][..., None, None]
    t_max = t.flatten(-2).max(dim=-1)[0][..., None, None]
    return (t - t_min) / (t_max - t_min).clamp_min(torch.finfo(t.dtype).tiny)
|
|
| def _upsample_like(src,tar,mode='bilinear'): |
| if mode == 'bilinear': |
| src = F.upsample(src,size=tar.shape[2:],mode=mode,align_corners=True) |
| else: |
| src = F.upsample(src,size=tar.shape[2:],mode=mode) |
| return src |
|
|
| def _upsample_(src,size,mode='bilinear'): |
| if mode == 'bilinear': |
| src = F.upsample(src,size=size,mode=mode,align_corners=True) |
| else: |
| src = F.upsample(src,size=size,mode=mode) |
| return src |
|
|
def patchfy(x, p=4, c=4):
    """Split an image batch into flattened, non-overlapping ``p x p`` patches.

    Args:
        x: ``(N, c, H, W)`` tensor with ``H`` and ``W`` divisible by ``p``.
            Non-square inputs are supported.
        p: patch side length.
        c: number of channels in ``x``.

    Returns:
        ``(N, (H // p) * (W // p), p * p * c)`` tensor. Patches are in
        row-major order; each patch is flattened as (patch row, patch column,
        channel).
    """
    h = x.shape[2] // p
    # Use the real width; the old ``h = w = H // p`` broke non-square inputs.
    w = x.shape[3] // p
    x = x.reshape(x.shape[0], c, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(x.shape[0], h * w, p ** 2 * c)
| |
def unpatchfy(x, p=4, c=4):
    """Reassemble flattened ``p x p`` patches into an image batch.

    Args:
        x: ``(N, L, p * p * c)`` tensor. The ``L`` tokens are assumed to form
            a square grid in row-major order, each flattened as
            (patch row, patch column, channel).
        p: patch side length.
        c: number of channels.

    Returns:
        ``(N, c, side * p, side * p)`` tensor with ``side = round(sqrt(L))``.
    """
    side = round(x.shape[1] ** 0.5)
    batch = x.shape[0]
    grid = x.reshape(batch, side, side, p, p, c)
    grid = torch.einsum('nhwpqc->nchpwq', grid)
    return grid.reshape(batch, c, side * p, side * p)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
def structure_loss(pred, mask):
    """Boundary-weighted BCE plus weighted IoU loss.

    Each pixel gets weight ``1 + 5 * |avg_pool31(mask) - mask|``, so pixels
    whose 31x31 neighbourhood disagrees with them are emphasised.

    Args:
        pred: logits, same shape as ``mask`` (channels on dim 1, spatial
            dims 2 and 3).
        mask: target map — presumably values in ``[0, 1]``; TODO confirm.

    Returns:
        Scalar: mean over batch and channels of
        ``weighted BCE + (1 - weighted IoU)``. The IoU term uses sigmoid
        probabilities and +1 smoothing.
    """
    spatial = (2, 3)
    local_mean = F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)
    weight = 1 + 5 * torch.abs(local_mean - mask)

    bce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    weighted_bce = (weight * bce).sum(dim=spatial) / weight.sum(dim=spatial)

    prob = torch.sigmoid(pred)
    inter = ((prob * mask) * weight).sum(dim=spatial)
    union = ((prob + mask) * weight).sum(dim=spatial)
    weighted_iou = 1 - (inter + 1) / (union - inter + 1)

    return (weighted_bce + weighted_iou).mean()
|
|
def iou_loss(pred, mask):
    """Soft IoU loss averaged over batch and channels.

    Sums run over dims 2 and 3. ``pred`` is used as-is (no sigmoid), so it
    presumably already holds probabilities — TODO confirm with callers.

    Returns:
        ``mean(1 - (inter + eps) / (union + eps))`` with ``eps = 1e-6``.
    """
    smooth = 1e-6
    overlap = (pred * mask).sum(dim=(2, 3))
    combined = (pred + mask).sum(dim=(2, 3)) - overlap
    return (1 - (overlap + smooth) / (combined + smooth)).mean()
|
|
def dice_loss(pred, mask):
    """Soft Dice loss averaged over the batch.

    Each sample is flattened. The per-sample coefficient is
    ``(2 * intersection + eps) / (sum(pred) + sum(mask) + eps)`` with
    ``eps = 1e-6``, and the loss is ``1 - mean(coefficient)``.

    Args:
        pred: predictions, batch on dim 0. Used as-is (no sigmoid).
        mask: targets with the same number of elements per sample.
    """
    eps = 1e-6
    N = pred.size(0)
    # ``reshape`` (not ``view``) so non-contiguous inputs, e.g. transposed or
    # sliced tensors, do not raise.
    pred_flat = pred.reshape(N, -1)
    mask_flat = mask.reshape(N, -1)

    intersection = (pred_flat * mask_flat).sum(1)
    dice_coefficient = (2. * intersection + eps) / (pred_flat.sum(1) + mask_flat.sum(1) + eps)
    return 1 - dice_coefficient.mean()
|
|
class LargeK(nn.Module):
    """Large-kernel block built from three dilated depthwise convolutions.

    Pipeline:

    * A 1x1 conv expands ``dim`` channels to ``3 * dim``.
    * The three ``dim``-channel chunks go through 7x7 depthwise convs with
      dilation 2, 4 and 8 (effective kernel extents 13, 25 and 49).
    * The chunks are concatenated and a 1x1 conv mixes them back to ``dim``
      channels.

    Padding ``3 * dilation`` keeps the spatial size unchanged.

    Args:
        dim (int): Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.channel_split = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.dwconv1 = nn.Conv2d(dim, dim, kernel_size=7, dilation=2, padding=6, groups=dim)
        self.dwconv2 = nn.Conv2d(dim, dim, kernel_size=7, dilation=4, padding=12, groups=dim)
        self.dwconv3 = nn.Conv2d(dim, dim, kernel_size=7, dilation=8, padding=24, groups=dim)
        self.channel_mix = nn.Conv2d(dim * 3, dim, kernel_size=1)

    def forward(self, x):
        x = self.channel_split(x)
        x1, x2, x3 = torch.chunk(x, 3, dim=1)
        x1 = self.dwconv1(x1)
        x2 = self.dwconv2(x2)
        x3 = self.dwconv3(x3)
        x = torch.cat([x1, x2, x3], dim=1)
        return self.channel_mix(x)
|
|
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective: 'vanilla', 'lsgan' or 'wgangp'.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Raises:
            NotImplementedError -- if gan_mode is not one of the supported modes.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Buffers so the labels follow the module across .to(device) calls.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for real or fake images

        Returns:
            A label tensor filled with the ground-truth label, expanded to the size of the input.
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Calculate the loss given the discriminator's output and ground-truth labels.

        Defined as ``forward`` (not by overriding ``__call__``) so
        ``nn.Module.__call__`` still dispatches here and runs registered hooks.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for real or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        if self.gan_mode == 'wgangp':
            # Wasserstein critic: raise scores on real samples, lower them on fake ones.
            return -prediction.mean() if target_is_real else prediction.mean()
        raise NotImplementedError('gan mode %s not implemented' % self.gan_mode)
|
|
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    """Build a sinusoidal position table of shape ``(B, heads, max_len, D)``.

    For position ``p`` and frequency index ``i``, the angle is
    ``p * 10000 ** (-2i / D)``. Column ``2i`` holds its sine and column
    ``2i + 1`` its cosine. The same table is repeated for every batch entry
    and head, then moved to ``device``. ``D`` (``output_dim``) must be even.
    """
    positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    freq_idx = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * freq_idx / output_dim)
    angles = positions * inv_freq  # (max_len, output_dim // 2)

    # Interleave sin/cos along the feature axis: [sin0, cos0, sin1, cos1, ...].
    table = torch.stack([angles.sin(), angles.cos()], dim=-1).reshape(max_len, output_dim)
    table = table.repeat(batch_size, nums_head, 1, 1)
    return table.to(device)
|
|
def RoPE(q, k):
    """Apply rotary position embeddings to query and key tensors.

    Adjacent feature pairs ``(2i, 2i + 1)`` at sequence position ``p`` are
    rotated by the angle ``p * 10000 ** (-2i / D)``, taken from
    ``sinusoidal_position_embedding``.

    Args:
        q: ``(B, heads, L, D)`` tensor, or ``(B, L, D)`` for a single head.
            ``D`` must be even.
        k: tensor with the same shape as ``q``; the rotation table is built
            from ``q``'s shape.

    Returns:
        ``(q, k)`` rotated, with the same shapes as the inputs.
    """
    # Fix: the original tested ``q.size() == 3``, comparing a torch.Size to an
    # int, which is always False. 3-D inputs therefore never got the singleton
    # head axis and failed to broadcast. ``dim()`` is the tensor rank.
    use_multi_head = True
    if q.dim() == 3 and k.dim() == 3:
        use_multi_head = False
        q, k = q[:, None, ...], k[:, None, ...]
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)

    # The table stores sin in even columns and cos in odd columns. Duplicate
    # each so both members of a feature pair share one angle.
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    # Rotate each pair by 90 degrees, (x0, x1) -> (-x1, x0), then mix with cos/sin.
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).reshape(q.shape)
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).reshape(k.shape)
    k = k * cos_pos + k2 * sin_pos

    if not use_multi_head:
        q, k = q[:, 0], k[:, 0]
    return q, k
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        hidden_size: input and output feature size.
        intermediate_size: width of the gated hidden layer.

    All three projections are bias-free linear layers.
    """

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # Creation order (w1, w2, w3) is kept so seeded initialization is unchanged.
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
|
class LayerNorm(nn.Module):
    """LayerNorm supporting channels_last (default) or channels_first inputs.

    ``channels_last`` expects ``(batch_size, height, width, channels)`` and
    delegates to ``F.layer_norm``. ``channels_first`` expects
    ``(batch_size, channels, height, width)`` and normalizes over dim 1 by
    hand, using the biased variance.

    Raises:
        NotImplementedError: if ``data_format`` is neither supported value.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            var = centered.pow(2).mean(1, keepdim=True)
            normed = centered / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # Any other value (only reachable if mutated after __init__) yields None.
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale.

    Args:
        hidden_size: size of the normalized (channel) axis.
        eps: constant added to the mean square before the reciprocal sqrt.
        data_format: ``"channels_first"`` normalizes over dim 1, with the
            scale broadcast over two trailing spatial dims (presumably
            ``(N, C, H, W)``). ``"channels_last"`` normalizes over the last dim.

    Raises:
        NotImplementedError: for any other ``data_format``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, data_format="channels_first") -> None:
        super().__init__()
        # Fail fast: an unknown format used to be accepted here and only
        # surfaced later as an UnboundLocalError or a ``None`` output.
        if data_format not in ("channels_first", "channels_last"):
            raise NotImplementedError(f"unsupported data_format: {data_format!r}")
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.data_format = data_format

    def _norm(self, hidden_states):
        """Divide by the root mean square over the configured channel axis."""
        dim = 1 if self.data_format == "channels_first" else -1
        variance = hidden_states.pow(2).mean(dim=dim, keepdim=True)
        return hidden_states * torch.rsqrt(variance + self.eps)

    def forward(self, hidden_states):
        # Normalize in float32 for stability, cast back to the input dtype,
        # then apply the learnable scale.
        normed = self._norm(hidden_states.float()).type_as(hidden_states)
        if self.data_format == "channels_first":
            return self.weight[..., None, None] * normed
        return self.weight * normed
|
|
class GRN(nn.Module):
    """GRN (Global Response Normalization) layer for channels-last input.

    ``gamma``/``beta`` are shaped ``(1, 1, 1, dim)``, the L2 norm is taken over
    dims 1 and 2, and the normalization averages over the last dim. The input
    is therefore expected as ``(N, H, W, dim)``. Both parameters start at
    zero, so the layer is initially the identity.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative) + self.beta + x
| |
class DUpsampling(nn.Module):
    """Learned upsampling: 1x1 conv, then a channel-to-space rearrangement.

    The conv maps ``inplanes`` to ``inplanes * scale**2`` channels. Permutes
    and views then fold the extra channels into space, giving an output of
    ``(N, inplanes, H' * scale, W' * scale)``, where ``H'``/``W'`` are the conv
    output's spatial dims (equal to the input's when ``pad == 0``).
    """

    def __init__(self, inplanes, scale, pad=0):
        super(DUpsampling, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, inplanes * scale * scale, kernel_size=1, padding=pad)
        self.scale = scale

    def forward(self, x):
        s = self.scale
        feat = self.conv1(x)
        n, ch, height, width = feat.size()
        # (N, C, H, W) -> (N, H, W*s, C/s): one factor of s moves into width.
        t = feat.permute(0, 2, 3, 1).contiguous().view(n, height, width * s, ch // s)
        # (N, W*s, H, C/s) -> (N, W*s, H*s, C/s^2): the second factor moves into height.
        t = t.permute(0, 2, 1, 3).contiguous().view(n, width * s, height * s, ch // (s * s))
        # Back to channels-first: (N, C/s^2, H*s, W*s).
        return t.permute(0, 3, 2, 1)
| |
class REsampling(nn.Module):
    """Parameter-free channel-to-space rearrangement by ``scale``.

    Turns ``(N, C, H, W)`` into ``(N, C // scale**2, H * scale, W * scale)``
    using only permutes and views; values are not modified. ``C`` must be
    divisible by ``scale**2``.
    """

    def __init__(self, scale):
        super(REsampling, self).__init__()
        self.scale = scale

    def forward(self, x):
        s = self.scale
        n, ch, height, width = x.size()
        # (N, C, H, W) -> (N, H, W*s, C/s): one factor of s moves into width.
        t = x.permute(0, 2, 3, 1).contiguous().view(n, height, width * s, ch // s)
        # (N, W*s, H, C/s) -> (N, W*s, H*s, C/s^2): the second factor moves into height.
        t = t.permute(0, 2, 1, 3).contiguous().view(n, width * s, height * s, ch // (s * s))
        # Back to channels-first: (N, C/s^2, H*s, W*s).
        return t.permute(0, 3, 2, 1)
| |
class Dcrop(nn.Module):
    """Space-to-channel rearrangement followed by a residual 3x3 conv.

    An ``(B, C, H, W)`` input becomes
    ``(B, C * cropscale**2, H // cropscale, W // cropscale)`` via permutes and
    views. ``self.conv`` (channel-preserving, padding 1) is then added on top
    as a residual.
    """

    def __init__(self, inplanes, cropscale=2):
        super(Dcrop, self).__init__()
        packed = inplanes * cropscale * cropscale
        self.conv = nn.Conv2d(packed, packed, kernel_size=3, padding=1)
        self.cropscale = cropscale

    def forward(self, x):
        s = self.cropscale
        b, c, h, w = x.size()
        # (B, C, H, W) -> (B, H, W/s, C*s): width factor moves into channels.
        t = x.permute(0, 2, 3, 1).contiguous().view(b, h, w // s, c * s)
        # (B, W/s, H, C*s) -> (B, W/s, H/s, C*s^2): height factor moves into channels.
        t = t.permute(0, 2, 1, 3).contiguous().view(b, w // s, h // s, c * s * s)
        packed = t.permute(0, 3, 2, 1)
        return self.conv(packed) + packed
| |
def show_gray_images(images, m=8, alpha=3, cmap='coolwarm', save_path=None):
    """Display, or save, one or more single-channel images with matplotlib.

    Args:
        images: a 2-D ``(H, W)`` image or a 3-D ``(N, H, W)`` stack (torch
            tensor or numpy array).
        m: maximum number of columns in the grid.
        alpha: size scale; each grid cell is ``2 * alpha`` inches square.
        cmap: matplotlib colormap name.
        save_path: if given, save the figure there and close it. Otherwise
            call ``plt.show()``.
    """
    # ``.ndim`` works for both torch tensors and numpy arrays (numpy's ``size``
    # is an int attribute, so ``images.size()`` failed on arrays).
    if images.ndim == 2:
        plt.imshow(images, cmap=cmap)
        plt.axis('off')
    else:
        n, _, _ = images.shape
        if n == 1:
            plt.imshow(images[0], cmap=cmap)
            plt.axis('off')
        else:
            cols = min(m, n)
            rows = (n + cols - 1) // cols
            # squeeze=False always yields a 2-D array of axes, so no special
            # case is needed for a single row or column.
            _, axes = plt.subplots(rows, cols, figsize=(cols * 2 * alpha, rows * 2 * alpha), squeeze=False)
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            for idx, ax in enumerate(axes.flat):
                if idx < n:
                    ax.imshow(images[idx], cmap=cmap)
                # Also hide unused trailing cells, which previously kept empty framed axes.
                ax.axis('off')
    if save_path is not None:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()
| |
|
|
|
|
|
|