# File view metadata: 29 lines, 753 B, Python
import torch
from torch import nn


class Classifier(nn.Module):
    """Fully connected classifier: an MLP feature extractor plus a linear head.

    The feature extractor (``self.stack``) maps inputs of width
    ``embedding_length`` through hidden widths 256 -> 256 -> 64 -> 64 down to
    ``out_len`` features, applying a ReLU after every linear layer.
    ``self.logits`` then projects those features to ``classes`` unnormalised
    scores.

    Args:
        embedding_length: Size of the last dimension of the input.
        classes: Number of output scores per sample.
        out_len: Width of the final hidden layer, i.e. of the features returned
            by ``get_last_layer`` and cached in ``last_hidden_layer``.
            Defaults to 16, the previously hard-coded value.

    Attributes:
        last_hidden_layer: Detached copy of the final hidden features from the
            most recent ``forward`` call, or ``None`` before the first call.
    """

    def __init__(self, embedding_length: int, classes: int, out_len: int = 16):
        super().__init__()
        # The layer order and the attribute names ``stack``/``logits`` determine
        # the state_dict keys ("stack.0.weight", ..., "logits.bias"). Keep them
        # unchanged so previously saved checkpoints still load.
        self.stack = nn.Sequential(
            nn.Linear(embedding_length, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, out_len),
            nn.ReLU(),
        )
        self.logits = nn.Linear(out_len, classes)
        # Overwritten by forward(). It is initialised here so that reading it
        # before the first forward pass yields None instead of raising
        # AttributeError.
        self.last_hidden_layer = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class scores for ``x`` and cache its final hidden features.

        Args:
            x: Input tensor whose last dimension is ``embedding_length``.
                Leading dimensions are passed through by ``nn.Linear``.

        Returns:
            Unnormalised scores whose last dimension is ``classes``.
        """
        features = self.get_last_layer(x)
        # The cached copy is detached, so it carries no autograd history.
        self.last_hidden_layer = features.detach()
        return self.logits(features)

    def get_last_layer(self, x: torch.Tensor) -> torch.Tensor:
        """Return the final hidden features for ``x`` (post-ReLU output of ``self.stack``).

        Unlike ``forward``, this does not update ``last_hidden_layer``, and the
        returned tensor is not detached.
        """
        return self.stack(x)