import torch
from torch import nn


class Classifier(nn.Module):
    """Multi-layer perceptron that maps input embeddings to class logits.

    The network is a stack of ReLU-activated linear layers
    (``embedding_length -> 256 -> 256 -> 64 -> 64 -> out_len``) followed by a
    final linear projection to ``classes`` logits. Every forward pass records
    the output of the stack (the last hidden layer) on
    ``self.last_hidden_layer`` so it can be inspected afterwards.

    Args:
        embedding_length: Size of the last dimension of the input tensor
            (consumed by the first ``nn.Linear``).
        classes: Number of output logits.
        out_len: Width of the last hidden layer. Defaults to 16, the value
            this class always used previously.
    """

    def __init__(self, embedding_length: int, classes: int, out_len: int = 16):
        super().__init__()
        self.stack = nn.Sequential(
            nn.Linear(embedding_length, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_len),
            # The last hidden layer is ReLU-activated too, so the recorded
            # features are non-negative.
            nn.ReLU(),
        )
        self.logits = nn.Linear(out_len, classes)
        # Overwritten by ``forward``. Initialised here so that reading it
        # before the first forward pass yields None instead of raising
        # AttributeError.
        self.last_hidden_layer = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``x`` and record the last hidden layer.

        Args:
            x: Input tensor whose last dimension is ``embedding_length``.

        Returns:
            Logits tensor with the same leading dimensions as ``x`` and a
            last dimension of size ``classes``.
        """
        hidden = self.stack(x)
        # Detach so the stored tensor does not keep the autograd graph alive
        # between calls. Note that it still shares storage with ``hidden``.
        self.last_hidden_layer = hidden.detach()
        return self.logits(hidden)

    def get_last_layer(self, x: torch.Tensor) -> torch.Tensor:
        """Return the last hidden layer's activations for ``x``.

        Unlike ``forward``, this neither stores the result nor detaches it,
        so gradients flow through it as usual.

        Args:
            x: Input tensor whose last dimension is ``embedding_length``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_len``.
        """
        return self.stack(x)