wwu-577/src/train/model.py

29 lines
753 B
Python

from torch import nn
class Classifier(nn.Module):
    """Feed-forward classifier over fixed-length embedding vectors.

    The network is a stack of ``Linear -> ReLU`` layers with widths
    256, 256, 64, 64 and ``out_len``, followed by a final linear layer
    that produces one (unnormalised) logit per class.

    After each call to :meth:`forward`, the detached output of the last
    hidden layer is cached in ``self.last_hidden_layer`` so callers can
    inspect the learned representation without re-running the stack.
    """

    def __init__(self, embedding_length: int, classes: int, out_len: int = 16):
        """Build the layer stack.

        Args:
            embedding_length: Size of the last dimension of the input.
            classes: Number of output logits.
            out_len: Width of the last hidden layer (the representation
                cached in ``last_hidden_layer``). Defaults to 16, the
                previously hard-coded value.
        """
        super().__init__()
        self.stack = nn.Sequential(
            nn.Linear(embedding_length, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_len),
            # The ReLU on the last hidden layer is kept so the cached /
            # returned representation matches what feeds the logits.
            nn.ReLU(),
        )
        self.logits = nn.Linear(out_len, classes)
        # Populated by forward(); initialised here so reading it before the
        # first forward pass yields None instead of raising AttributeError.
        self.last_hidden_layer = None

    def forward(self, x):
        """Return class logits for ``x`` and cache the last hidden layer.

        ``x`` is expected to have ``embedding_length`` as its last
        dimension (as required by the first ``nn.Linear``).
        """
        hidden = self.get_last_layer(x)
        # Detached so the cached tensor does not keep the autograd graph alive.
        self.last_hidden_layer = hidden.detach()
        return self.logits(hidden)

    def get_last_layer(self, x):
        """Return the (non-detached) output of the last hidden layer for ``x``."""
        return self.stack(x)