From f3d76780662a614ffe2ff42531c55fc695a520f3 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 8 May 2023 14:13:56 -0700 Subject: [PATCH] add another emotion pretrained model. --- src/cli.py | 2 ++ src/emotion.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++- src/word.py | 18 ++++++++++++ 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/src/cli.py b/src/cli.py index 64b93f8..b05e526 100644 --- a/src/cli.py +++ b/src/cli.py @@ -28,4 +28,6 @@ if __name__ == "__main__": cli.add_command(emotion.normalize) cli.add_command(emotion.analyze) cli.add_command(emotion.create_table) + import sentence + cli.add_command(sentence.embed) cli() diff --git a/src/emotion.py b/src/emotion.py index 363c975..777b00e 100644 --- a/src/emotion.py +++ b/src/emotion.py @@ -6,11 +6,12 @@ import numpy as np from transformers import BertTokenizer from model import BertForMultiLabelClassification -from data import connect +from data import connect, data_dir import seaborn as sns import matplotlib.pyplot as plt def data(): + # load data DB = connect() table = DB.sql(""" @@ -26,6 +27,7 @@ def data(): ORDER BY id DESC """).df() DB.close() + return table @click.command("emotion:create-table") @@ -167,3 +169,79 @@ def analyze(): sns.lineplot(x=df['year'], y=df['frac'], hue=df['label']) plt.show() + +def debug(): + from transformers import pipeline + + # load data + DB = connect() + table = DB.sql(""" + SELECT + id, + title + FROM stories + ORDER BY id DESC + """).df() + DB.close() + + classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base") + + chunks = 5000 + chunked = np.array_split(table, chunks) + labels = [] + ids = [] + for chunk in tqdm(chunked): + sentences = chunk['title'].tolist() + label_ids = chunk['id'].tolist() + with torch.no_grad(): + emotions = classifier(sentences) + labels.append(emotions) + ids.append(label_ids) + out = pd.DataFrame(np.concatenate(labels).tolist()) + out_ids = pd.DataFrame(np.concatenate(ids).tolist(), columns=['story_id']) + out = pd.concat([out_ids, out], axis=1) + + DB = connect() + DB.sql(""" + CREATE OR REPLACE TABLE story_emotions AS + SELECT + story_id + ,label + ,score + FROM out + """) + DB.sql(""" + CREATE OR REPLACE TABLE emotions AS + SELECT + row_number() over() as id + ,label + ,count(1) as stories + FROM story_emotions + GROUP BY + label + """) + DB.sql(""" + ALTER TABLE story_emotions add emotion_id bigint + """) + DB.sql(""" + UPDATE story_emotions + SET emotion_id = emotions.id + FROM emotions + WHERE story_emotions.label = emotions.label + """) + DB.sql(""" + ALTER TABLE story_emotions drop column label + """) + DB.sql(""" + select + * + from emotions + """) + DB.sql(""" + select + * from story_emotions + limit 4 + """) + DB.close() + + out.to_csv(data_dir() / 'emotions.csv', sep="|") diff --git a/src/word.py b/src/word.py index 980787b..dc408d5 100644 --- a/src/word.py +++ b/src/word.py @@ -81,3 +81,21 @@ def distance(): min_index = (np.argmin(distances)) closest = np.unravel_index(min_index, distances.shape) distances.flatten().shape + +# path = data_dir() / 'embeddings' +# chunks = [x for x in path.iterdir() if x.match('*.npy')] +# chunks = sorted(chunks, key=lambda x: int(x.stem.split('_')[1])) +# +# data = None +# for i, f in enumerate(tqdm(chunks)): +# loaded = np.load(f) +# if data is None: +# data = loaded +# else: +# data = np.concatenate([data, loaded]) +# if i > 20: +# break +# +# data.shape +# +# np.save(data, data_dir() / 'embeddings.npy')