# wwu-577/src/model.py

# 17 lines
# 516 B
# Python

from transformers import AutoTokenizer, RobertaModel
import torch
from torch import nn
class Model(nn.Module):
    """RoBERTa encoder followed by a linear layer and a sigmoid.

    The linear head is applied directly to ``last_hidden_state`` with no
    pooling. The output therefore keeps the encoder's per-position layout,
    with the hidden dimension replaced by ``n_classes``. Each entry is
    squashed by a sigmoid into the range [0, 1].

    NOTE(review): a sigmoid head suggests multi-label targets trained with
    ``BCELoss``. If so, returning raw logits and using ``BCEWithLogitsLoss``
    is numerically more stable. The sigmoid is kept here to preserve the
    existing return contract.

    NOTE(review): if one prediction per sequence is intended, a pooled
    representation (e.g. the first token) should be selected before the
    linear layer. Confirm against callers and the loss function.
    """

    def __init__(self, n_classes=10, model_name="roberta-base"):
        """Build the encoder and the classification head.

        Args:
            n_classes: Number of output units of the linear head
                (default 10, the previously hard-coded value).
            model_name: Checkpoint name or path passed to
                ``RobertaModel.from_pretrained`` (default ``"roberta-base"``).
        """
        super().__init__()
        self.n_classes = n_classes
        self.bert = RobertaModel.from_pretrained(model_name)
        # Head input width follows whichever checkpoint was actually loaded.
        self.linear = nn.Linear(self.bert.config.hidden_size, self.n_classes)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Encode ``x`` and return sigmoid scores for each class.

        Args:
            x: Mapping unpacked as keyword arguments into the encoder.
                Presumably this is tokenizer output (``input_ids``,
                ``attention_mask``, ...); confirm against callers.

        Returns:
            Tensor of sigmoid scores whose leading dimensions match
            ``last_hidden_state`` (presumably ``(batch, seq_len)``) and whose
            last dimension is ``n_classes``.
        """
        outs = self.bert(**x)
        return self.act(self.linear(outs.last_hidden_state))