"""
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""
import torch
from torch import nn

try:
    import segmentation_models_pytorch as smp

    class Transformer(nn.Module):
        """ Transformer encoder from segformer paper with MAnet decoder 
            (configuration from MEDIAR)
        """

        def __init__(self, encoder="mit_b5", encoder_weights=None, decoder="MAnet",
                     diam_mean=30.):
            super().__init__()
            net_fcn = smp.MAnet if decoder == "MAnet" else smp.FPN
            self.encoder = encoder
            self.decoder = decoder
            self.net = net_fcn(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                # (use "imagenet" pre-trained weights for encoder initialization if training)
                in_channels=3,
                classes=3,
                activation=None)
            self.nout = 3  # number of output channels
            self.mkldnn = False  # mkldnn acceleration is not used for this backbone
            # mean diameters stored as non-trainable parameters so they persist in the state dict
            self.diam_mean = nn.Parameter(data=torch.ones(1) * diam_mean,
                                          requires_grad=False)
            self.diam_labels = nn.Parameter(data=torch.ones(1) * diam_mean,
                                            requires_grad=False)

        def forward(self, X):
            # have to convert to 3-chan (RGB)
            if X.shape[1] < 3:
                X = torch.cat(
                    (X,
                     torch.zeros((X.shape[0], 3 - X.shape[1], X.shape[2], X.shape[3]),
                                 device=X.device)), dim=1)
            y = self.net(X)
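            # second output is a placeholder 256-dim style vector (all zeros;
            # this backbone computes no style embedding)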
            return y, torch.zeros((X.shape[0], 256), device=X.device)

        @property
        def device(self):
            return next(self.parameters()).device

        def save_model(self, filename):
            """
            Save the model to a file.

            Args:
                filename (str): The path to the file where the model will be saved.
            """
            torch.save(self.state_dict(), filename)

        def load_model(self, filename, device=None):
            """
            Load the model from a file.

            Args:
                filename (str): The path to the file where the model is saved.
                device (torch.device, optional): The device to load the model on. Defaults to None.
            """
            if (device is not None) and (device.type != "cpu"):
                state_dict = torch.load(filename, map_location=device, weights_only=True)
            else:
                # re-initialize the network on CPU before loading the weights
                self.__init__(encoder=self.encoder, decoder=self.decoder,
                              diam_mean=self.diam_mean)
                state_dict = torch.load(filename, map_location=torch.device("cpu"),
                                        weights_only=True)

            self.load_state_dict(state_dict, strict=False)

except Exception as e:
    print(e)
    print("segmentation_models_pytorch is required to use the Transformer model "
          "(pip install segmentation_models_pytorch)")