# Spaces: Runtime error
import timm
import torch.nn as nn
class TimmUniversalEncoder(nn.Module):
    """Encoder that exposes a timm model as a list of feature maps.

    The wrapped model is created with ``timm.create_model`` using
    ``features_only=True`` and ``out_indices=tuple(range(depth))``.
    ``forward`` returns the input followed by the model's features, and
    ``out_channels`` reports the matching channel counts.

    Args:
        name: Model name passed to ``timm.create_model``.
        pretrained: Forwarded to timm as ``pretrained``.
        in_channels: Forwarded to timm as ``in_chans``; also reported as the
            first entry of ``out_channels()``.
        depth: Number of feature stages requested (indices ``0..depth-1``).
        output_stride: Forwarded to timm unless it equals 32, in which case
            the argument is omitted entirely.
    """

    def __init__(
        self,
        name: str,
        pretrained: bool = True,
        in_channels: int = 3,
        depth: int = 5,
        output_stride: int = 32,
    ):
        super().__init__()

        kwargs = dict(
            in_chans=in_channels,
            features_only=True,
            output_stride=output_stride,
            pretrained=pretrained,
            out_indices=tuple(range(depth)),
        )

        # not all models support output stride argument, drop it by default
        if output_stride == 32:
            kwargs.pop("output_stride")

        self.model = timm.create_model(name, **kwargs)

        self._in_channels = in_channels
        # Entry 0 describes the raw input that forward() prepends; the rest
        # come from the wrapped model's feature_info.
        self._out_channels = [in_channels, *self.model.feature_info.channels()]
        self._depth = depth

    def forward(self, x) -> list:
        """Return ``[x, *features]`` where ``features = self.model(x)``.

        Unpacking (rather than ``list + list``) keeps this working when the
        wrapped model returns a tuple or another iterable instead of a list.
        """
        features = self.model(x)
        return [x, *features]

    def out_channels(self) -> list:
        """Return the channel counts stored at construction time.

        This is a regular method, not a property: call it as
        ``encoder.out_channels()``.
        """
        return self._out_channels