Spaces:
Runtime error
Runtime error
Upload encoders/_utils.py
Browse files- encoders/_utils.py +59 -0
encoders/_utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
|
| 6 |
+
"""Change first convolution layer input channels.
|
| 7 |
+
In case:
|
| 8 |
+
in_channels == 1 or in_channels == 2 -> reuse original weights
|
| 9 |
+
in_channels > 3 -> make random kaiming normal initialization
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
# get first conv
|
| 13 |
+
for module in model.modules():
|
| 14 |
+
if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
|
| 15 |
+
break
|
| 16 |
+
|
| 17 |
+
weight = module.weight.detach()
|
| 18 |
+
module.in_channels = new_in_channels
|
| 19 |
+
|
| 20 |
+
if not pretrained:
|
| 21 |
+
module.weight = nn.parameter.Parameter(
|
| 22 |
+
torch.Tensor(
|
| 23 |
+
module.out_channels,
|
| 24 |
+
new_in_channels // module.groups,
|
| 25 |
+
*module.kernel_size
|
| 26 |
+
)
|
| 27 |
+
)
|
| 28 |
+
module.reset_parameters()
|
| 29 |
+
|
| 30 |
+
elif new_in_channels == 1:
|
| 31 |
+
new_weight = weight.sum(1, keepdim=True)
|
| 32 |
+
module.weight = nn.parameter.Parameter(new_weight)
|
| 33 |
+
|
| 34 |
+
else:
|
| 35 |
+
new_weight = torch.Tensor(
|
| 36 |
+
module.out_channels,
|
| 37 |
+
new_in_channels // module.groups,
|
| 38 |
+
*module.kernel_size
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
for i in range(new_in_channels):
|
| 42 |
+
new_weight[:, i] = weight[:, i % default_in_channels]
|
| 43 |
+
|
| 44 |
+
new_weight = new_weight * (default_in_channels / new_in_channels)
|
| 45 |
+
module.weight = nn.parameter.Parameter(new_weight)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def replace_strides_with_dilation(module, dilation_rate):
|
| 49 |
+
"""Patch Conv2d modules replacing strides with dilation"""
|
| 50 |
+
for mod in module.modules():
|
| 51 |
+
if isinstance(mod, nn.Conv2d):
|
| 52 |
+
mod.stride = (1, 1)
|
| 53 |
+
mod.dilation = (dilation_rate, dilation_rate)
|
| 54 |
+
kh, kw = mod.kernel_size
|
| 55 |
+
mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate)
|
| 56 |
+
|
| 57 |
+
# Kostyl for EfficientNet
|
| 58 |
+
if hasattr(mod, "static_padding"):
|
| 59 |
+
mod.static_padding = nn.Identity()
|