# Spaces: Runtime error
from collections import OrderedDict
from typing import List

import torch
import torch.nn as nn

from . import _utils as utils
class EncoderMixin:
    """Add encoder functionality such as:
    - output channels specification of feature tensors (produced by encoder)
    - patching first convolution for arbitrary input channels

    The host class must provide ``_out_channels`` (a sliceable, indexable
    sequence of per-stage channel counts) and ``_depth`` (an int). Encoders
    that support dilation must also override :meth:`get_stages`.
    """

    # NOTE(review): if callers read this as an attribute (``encoder.out_channels``)
    # rather than calling it, it needs a ``@property`` decorator; the decorator
    # may have been lost in this copy. Confirm against callers before changing.
    def out_channels(self):
        """Return channels dimensions for each tensor of forward output of encoder.

        Only the first ``self._depth + 1`` entries of ``self._out_channels``
        are returned.
        """
        return self._out_channels[: self._depth + 1]

    def set_in_channels(self, in_channels, pretrained=True):
        """Change first convolution channels.

        Args:
            in_channels: Number of input channels the encoder should accept.
                A value of 3 returns immediately without changing anything
                (presumably 3 is the encoder's default; confirm).
            pretrained: Forwarded unchanged to ``utils.patch_first_conv``.
        """
        if in_channels == 3:
            return

        self._in_channels = in_channels
        # A leading entry of 3 presumably describes the raw input tensor, so it
        # is rewritten to track the new input channel count (assumption; verify).
        if self._out_channels[0] == 3:
            self._out_channels = (in_channels, *self._out_channels[1:])

        utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)

    def get_stages(self):
        """Return the encoder's stages, indexable by stage number.

        Must be overridden by the concrete encoder: the base implementation
        raises ``NotImplementedError``. :meth:`make_dilated` indexes the result
        with stage numbers 4 and 5.
        """
        raise NotImplementedError

    def make_dilated(self, output_stride):
        """Apply ``utils.replace_strides_with_dilation`` to late encoder stages.

        Args:
            output_stride: 16 dilates stage 5 with rate 2; 8 dilates stage 4
                with rate 2 and stage 5 with rate 4.

        Raises:
            ValueError: If ``output_stride`` is neither 16 nor 8. This is
                checked before :meth:`get_stages` is called.
        """
        # (stage index, dilation rate) pairs for each supported output stride.
        if output_stride == 16:
            schedule = [(5, 2)]
        elif output_stride == 8:
            schedule = [(4, 2), (5, 4)]
        else:
            raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))

        stages = self.get_stages()
        for stage_indx, dilation_rate in schedule:
            utils.replace_strides_with_dilation(
                module=stages[stage_indx],
                dilation_rate=dilation_rate,
            )