# wwu-577/src/data/sentiment.py
#
# 87 lines
# 2.6 KiB
# Python

import click
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import torch.nn.functional as F
from data.main import connect, paths
import numpy as np
from tqdm import tqdm
import pandas as pd
@click.option('-c', '--chunks', type=int, default=500, show_default=True)
@click.command("sentiment:extract")
def extract(chunks):
    """Classify every story title and save the predicted class ids to disk.

    Reads ``id`` and ``title`` from the ``stories`` table (highest id first),
    drops non-ASCII characters from the titles, runs the titles through a
    DistilBERT sequence classifier in ``chunks`` batches, and writes two
    row-aligned arrays under ``paths('data')``:

    * ``sentiment.npy``     -- argmax class id for each title
    * ``sentiment_ids.npy`` -- the story id for each prediction

    Args:
        chunks: Number of batches the titles are split into. Must be >= 1.

    Raises:
        click.BadParameter: If ``chunks`` is smaller than 1.
    """
    if chunks < 1:
        raise click.BadParameter("must be a positive integer", param_hint="--chunks")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model from HuggingFace Hub.
    # NOTE(review): "distilbert-base-uncased" is the base checkpoint and does
    # not ship a trained classification head, so the head is freshly
    # initialized and the predicted classes are not meaningful sentiment.
    # `sentiment:load` labels class 1 as 'positive' -- confirm whether a
    # checkpoint fine-tuned for sentiment was intended here.
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    model = model.to(device)
    model.eval()  # inference only: keep dropout disabled

    # Load data; the context manager releases the connection even if the
    # query fails.
    with connect() as db:
        table = db.sql("""
select
id
,title
from stories
order by id desc
""").df()

    # Normalize text: decompose accented characters, then drop whatever is
    # still outside ASCII.
    table['title'] = (
        table['title']
        .str.normalize('NFKD')
        .str.encode('ascii', errors='ignore')
        .str.decode('utf-8')
    )

    titles = table['title'].tolist()
    ids = table['id'].tolist()

    # Split row positions rather than the DataFrame itself: the batch
    # boundaries are the same, but empty batches (fewer rows than chunks) are
    # easy to detect and skip.
    batches = np.array_split(np.arange(len(titles)), chunks)

    sentiments = []
    story_ids = []
    for positions in tqdm(batches, 'embedding'):
        if positions.size == 0:
            continue
        start, stop = int(positions[0]), int(positions[-1]) + 1
        sentences = titles[start:stop]

        # Tokenize sentences
        encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

        # Predict the class with the highest logit for each sentence
        with torch.no_grad():
            logits = model(**encoded_input.to(device)).logits
        sentiments.extend(logits.argmax(dim=1).tolist())
        story_ids.extend(ids[start:stop])

    sentiments = np.asarray(sentiments, dtype=np.int64)
    story_ids = np.asarray(story_ids)

    # save sentiments
    save_to = paths('data') / 'sentiment.npy'
    np.save(save_to, sentiments)
    print(f"sentiments saved: {save_to}")

    # save ids
    save_to = paths('data') / 'sentiment_ids.npy'
    np.save(save_to, story_ids)
    print(f"ids saved: {save_to}")
@click.command('sentiment:load')
def load():
    """Publish the saved sentiment predictions as the ``story_sentiments`` table.

    Loads ``sentiment.npy`` and ``sentiment_ids.npy`` from ``paths('data')``,
    pairs them row by row in a DataFrame, and (re)creates ``story_sentiments``
    from the rows whose id matches a row in ``stories``. Class id 1 is
    labelled 'positive'; any other class id is labelled 'negative'.
    """
    predicted = np.load(paths('data') / 'sentiment.npy')
    ids = np.load(paths('data') / 'sentiment_ids.npy')

    # The frame must stay bound to the local name `data`: the SQL below refers
    # to it by that name (presumably resolved through the connection's
    # DataFrame lookup -- confirm against `connect`).
    data = (
        pd.DataFrame(ids, columns=['story_id'])
        .reset_index()
        .assign(sentiment_id=predicted)
    )

    with connect() as db:
        db.query("""
CREATE OR REPLACE TABLE story_sentiments AS
SELECT
data.story_id
,data.sentiment_id as class_id
,CASE WHEN data.sentiment_id = 1 THEN 'positive' ELSE 'negative' end as label
FROM data
JOIN stories s
ON s.id = data.story_id
""")