import click
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import torch.nn.functional as F
from data.main import connect, paths
import numpy as np
from tqdm import tqdm
import pandas as pd

# Checkpoint fine-tuned for binary sentiment on SST-2 (class 0 = negative,
# class 1 = positive, matching the labels written by `sentiment:load`).
# The bare "distilbert-base-uncased" checkpoint ships without a trained
# classification head, so its predictions would be random.
SENTIMENT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"


@click.command("sentiment:extract")
@click.option('-c', '--chunks', type=click.IntRange(min=1), default=500, show_default=True,
              help='Number of batches to split the story titles into.')
def extract(chunks):
    """Classify the sentiment of every story title and save the predictions.

    Reads ``id`` and ``title`` from the ``stories`` table (rows with a NULL
    title are skipped), runs the titles through a DistilBERT sequence
    classifier in ``chunks`` batches, and writes two position-aligned arrays
    to ``paths('data')``:

    * ``sentiment.npy``     -- predicted class id per title (argmax of logits)
    * ``sentiment_ids.npy`` -- the story id for each prediction
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model from HuggingFace Hub
    tokenizer = DistilBertTokenizer.from_pretrained(SENTIMENT_MODEL)
    model = DistilBertForSequenceClassification.from_pretrained(SENTIMENT_MODEL)
    model = model.to(device)
    model.eval()  # inference only: make sure dropout is disabled

    # load data; the context manager closes the connection even on error
    with connect() as db:
        table = db.sql("""
            select
                id
                ,title
            from stories
            where title is not null
            order by id desc
        """).df()

    # Fold titles to plain ASCII: NFKD splits accented characters into a base
    # character plus combining marks, and the ASCII encode drops the marks
    # (and any other non-ASCII characters).
    table['title'] = (
        table['title']
        .str.normalize('NFKD')
        .str.encode('ascii', errors='ignore')
        .str.decode('utf-8')
    )

    # Split row *positions* rather than the DataFrame itself: np.array_split on
    # a DataFrame goes through the deprecated DataFrame.swapaxes. Empty batches
    # (fewer rows than chunks) are dropped because the tokenizer rejects [].
    batches = [idx for idx in np.array_split(np.arange(len(table)), chunks) if len(idx)]

    # Predicted class per row, filled in by position so it stays aligned
    # with table['id'].
    sentiments = np.empty(len(table), dtype=np.int64)
    for idx in tqdm(batches, desc='classifying'):
        titles = table['title'].iloc[idx].tolist()

        # Tokenize sentences (pad to the longest title in the batch)
        encoded_input = tokenizer(titles, padding=True, truncation=True, return_tensors='pt').to(device)

        # Predict sentiment classes
        with torch.no_grad():
            logits = model(**encoded_input).logits
        sentiments[idx] = logits.argmax(dim=1).cpu().numpy()

    story_ids = table['id'].to_numpy()

    data_dir = paths('data')

    # save sentiment class ids
    save_to = data_dir / 'sentiment.npy'
    np.save(save_to, sentiments)
    print(f"sentiments saved: {save_to}")

    # save ids
    save_to = data_dir / 'sentiment_ids.npy'
    np.save(save_to, story_ids)
    print(f"ids saved: {save_to}")


@click.command('sentiment:load')
def load():
    """Load saved sentiment predictions into the ``story_sentiments`` table.

    Reads the aligned ``sentiment.npy`` / ``sentiment_ids.npy`` arrays written
    by ``sentiment:extract`` and (re)creates ``story_sentiments`` with one row
    per prediction whose story still exists in ``stories``. Class 1 is
    labelled 'positive'; every other class is labelled 'negative'.

    Raises:
        click.ClickException: if the two saved arrays differ in length.
    """
    data_dir = paths('data')
    sentiments = np.load(data_dir / 'sentiment.npy')
    story_ids = np.load(data_dir / 'sentiment_ids.npy')
    if len(sentiments) != len(story_ids):
        raise click.ClickException(
            f"sentiment.npy has {len(sentiments)} rows but sentiment_ids.npy has {len(story_ids)}"
        )

    data = pd.DataFrame({'story_id': story_ids, 'sentiment_id': sentiments})

    with connect() as db:
        # `FROM data` below refers to the local DataFrame `data` above
        # (presumably resolved by DuckDB's replacement scan of Python
        # variables), so keep that variable name in sync with the SQL.
        db.query("""
            CREATE OR REPLACE TABLE story_sentiments AS
            SELECT
                data.story_id
                ,data.sentiment_id as class_id
                ,CASE WHEN data.sentiment_id = 1 THEN 'positive' ELSE 'negative' end as label
            FROM data
            JOIN stories s ON s.id = data.story_id
        """)