87 lines
2.6 KiB
Python
87 lines
2.6 KiB
Python
import click
|
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from data.main import connect, paths
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
import pandas as pd
|
|
|
|
@click.option('-c', '--chunks', type=int, default=500, show_default=True)
@click.command("sentiment:extract")
def extract(chunks):
    """Predict a class id for every story title and save the results to disk.

    Reads ``id`` and ``title`` from the ``stories`` table, runs the titles
    through ``DistilBertForSequenceClassification`` in ``chunks`` batches and
    writes two aligned arrays under ``paths('data')``:

    * ``sentiment.npy``     -- argmax class id for each title
    * ``sentiment_ids.npy`` -- the story id for each entry of ``sentiment.npy``

    Args:
        chunks: Number of batches the stories are split into for inference
            (CLI option ``-c/--chunks``). Batches that would be empty because
            there are fewer stories than ``chunks`` are skipped.

    Raises:
        click.BadParameter: If ``chunks`` is smaller than 1.
    """
    if chunks < 1:
        raise click.BadParameter("must be at least 1", param_hint="--chunks")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model from HuggingFace Hub
    # NOTE(review): "distilbert-base-uncased" is the base checkpoint; its
    # sequence-classification head does not appear to be fine-tuned, so the
    # predicted classes are probably not meaningful sentiment. A fine-tuned
    # checkpoint (e.g. an SST-2 one) is likely intended -- confirm before use.
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    model = model.to(device)
    model.eval()  # inference only: make sure dropout is disabled

    # load data; the context manager closes the connection even if the query fails
    with connect() as db:
        table = db.sql("""
            select
                id
                ,title
            from stories
            order by id desc
        """).df()

    # normalize text: missing titles become '' so the tokenizer only ever
    # receives strings, then strip everything that is not plain ASCII
    table['title'] = (
        table['title']
        .fillna('')
        .str.normalize('NFKD')
        .str.encode('ascii', errors='ignore')
        .str.decode('utf-8')
    )

    # Split row positions rather than the DataFrame itself: np.array_split on
    # a DataFrame relies on the deprecated DataFrame.swapaxes.
    positions = np.array_split(np.arange(len(table)), chunks)
    batches = [table.iloc[rows] for rows in positions if len(rows)]

    # predict a class id for every title, batch by batch
    sentiments = []
    story_ids = []
    for batch in tqdm(batches, 'embedding'):
        sentences = batch['title'].tolist()
        # Tokenize sentences
        encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        # Compute class logits without tracking gradients
        with torch.no_grad():
            logits = model(**encoded_input.to(device)).logits
        sentiments.extend(logits.argmax(dim=1).tolist())
        story_ids.extend(batch['id'].tolist())

    # Flat lists convert cleanly even when no story was processed
    # (np.concatenate raises on an empty list of arrays).
    sentiments = np.asarray(sentiments, dtype=np.int64)
    story_ids = np.asarray(story_ids)

    # save sentiments
    save_to = paths('data') / 'sentiment.npy'
    np.save(save_to, sentiments)
    print(f"sentiments saved: {save_to}")

    # save ids
    save_to = paths('data') / 'sentiment_ids.npy'
    np.save(save_to, story_ids)
    print(f"ids saved: {save_to}")
|
|
|
|
@click.command('sentiment:load')
def load():
    """Rebuild the ``story_sentiments`` table from the saved ``.npy`` arrays.

    Loads ``sentiment.npy`` and ``sentiment_ids.npy`` from ``paths('data')``,
    pairs them row by row in a DataFrame and (re)creates ``story_sentiments``
    with one row per entry that joins to an existing ``stories.id``. Class id
    1 is labelled ``'positive'``; every other value is labelled ``'negative'``.
    """
    predicted = np.load(paths('data') / 'sentiment.npy')
    ids = np.load(paths('data') / 'sentiment_ids.npy')

    # The SQL below refers to this frame by its local variable name `data`
    # (presumably resolved by the database driver -- verify against connect()),
    # so the name must not change.
    data = pd.DataFrame(ids, columns=['story_id'])
    data = data.reset_index()
    data['sentiment_id'] = predicted

    statement = """
        CREATE OR REPLACE TABLE story_sentiments AS
        SELECT
            data.story_id
            ,data.sentiment_id as class_id
            ,CASE WHEN data.sentiment_id = 1 THEN 'positive' ELSE 'negative' end as label
        FROM data
        JOIN stories s
            ON s.id = data.story_id
    """
    with connect() as db:
        db.query(statement)
|