from transformers import AutoTokenizer, RobertaModel
import torch
from torch import nn
class Model(nn.Module):
    """RoBERTa encoder with a linear + sigmoid classification head.

    Wraps a pretrained ``roberta-base`` encoder and maps each hidden state
    to ``n_classes`` independent sigmoid probabilities (multi-label style).
    """

    def __init__(self, n_classes: int = 10):
        """Build the encoder and classification head.

        Args:
            n_classes: number of output classes. Defaults to 10, matching
                the previously hard-coded value, so existing callers are
                unaffected.
        """
        super().__init__()
        self.n_classes = n_classes
        # NOTE(review): from_pretrained fetches weights from the HF hub on
        # first use — requires network access or a local cache.
        self.bert = RobertaModel.from_pretrained("roberta-base")
        # Use the imported `nn` alias consistently (original mixed
        # `torch.nn.*` with `nn.*`).
        self.linear = nn.Linear(self.bert.config.hidden_size, self.n_classes)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Encode inputs and apply the classification head.

        Args:
            x: dict of tokenizer outputs (e.g. ``input_ids``,
                ``attention_mask``) unpacked as keyword arguments into the
                RoBERTa encoder — assumes keys match
                ``RobertaModel.forward``'s signature.

        Returns:
            Tensor of shape ``(batch, seq_len, n_classes)`` with sigmoid
            probabilities for every token position.
        """
        outs = self.bert(**x)
        # NOTE(review): the head is applied per token (last_hidden_state),
        # not to a pooled sequence representation — confirm this is the
        # intended behavior for the downstream task.
        return self.act(self.linear(outs.last_hidden_state))