Add another pretrained emotion model.

This commit is contained in:
Matt 2023-05-08 14:13:56 -07:00
parent 3a6f97b290
commit f3d7678066
3 changed files with 99 additions and 1 deletion

View File

@@ -28,4 +28,6 @@ if __name__ == "__main__":
cli.add_command(emotion.normalize)
cli.add_command(emotion.analyze)
cli.add_command(emotion.create_table)
# register the new sentence-embedding command
import sentence
cli.add_command(sentence.embed)
cli()
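The hunk above wires a new sentence.embed click command in next to the emotion commands, but the embed command's body is not shown in this diff. A minimal sketch of what such a module could look like follows; the file name sentence.py, the command name "sentence:embed", and the all-MiniLM-L6-v2 model are assumptions that only mirror the "emotion:*" pattern used here, not something the commit confirms.

# sentence.py -- hypothetical sketch; nothing below is taken from the commit itself.
import click
import numpy as np
from sentence_transformers import SentenceTransformer

from data import connect, data_dir


@click.command("sentence:embed")
def embed():
    """Embed story titles with a sentence-transformer and save the matrix to disk."""
    DB = connect()
    table = DB.sql("SELECT id, title FROM stories ORDER BY id DESC").df()
    DB.close()

    model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model, not named in the commit
    embeddings = model.encode(table['title'].tolist(), show_progress_bar=True)
    np.save(data_dir() / "embeddings.npy", embeddings)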

View File

@@ -6,11 +6,12 @@ import numpy as np
from transformers import BertTokenizer
from model import BertForMultiLabelClassification
from data import connect
from data import connect, data_dir
import seaborn as sns
import matplotlib.pyplot as plt
def data():
# load data
DB = connect()
table = DB.sql("""
@@ -26,6 +27,7 @@ def data():
ORDER BY id DESC
""").df()
DB.close()
return table
@click.command("emotion:create-table")
@@ -167,3 +169,79 @@ def analyze():
sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
plt.show()
# run the pretrained j-hartmann/emotion-english-distilroberta-base classifier over story titles
# and rebuild the story_emotions / emotions tables from its output
def debug():
from transformers import pipeline
# load data
DB = connect()
table = DB.sql("""
SELECT
id,
title
FROM stories
ORDER BY id DESC
""").df()
DB.close()
# pretrained emotion classifier (DistilRoBERTa fine-tuned for English emotion labels)
classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
# split the stories into 5000 chunks so titles are classified in small batches
chunks = 5000
chunked = np.array_split(table, chunks)
labels = []
ids = []
for chunk in tqdm(chunked):
sentences = chunk['title'].tolist()
label_ids = chunk['id'].tolist()
# inference only, so disable gradient tracking
with torch.no_grad():
emotions = classifier(sentences)
labels.append(emotions)
ids.append(label_ids)
# flatten the per-chunk results into one DataFrame of (story_id, label, score) rows
out = pd.DataFrame(np.concatenate(labels).tolist())
out_ids = pd.DataFrame(np.concatenate(ids).tolist(), columns=['story_id'])
out = pd.concat([out_ids, out], axis=1)
DB = connect()
# materialize the raw predictions; DuckDB resolves `out` from the local DataFrame
DB.sql("""
CREATE OR REPLACE TABLE story_emotions AS
SELECT
story_id
,label
,score
FROM out
""")
# build an emotions dimension table with one row per distinct label
DB.sql("""
CREATE OR REPLACE TABLE emotions AS
SELECT
row_number() over() as id
,label
,count(1) as stories
FROM story_emotions
GROUP BY
label
""")
# add a foreign key to emotions, backfill it from the label, then drop the redundant label column
DB.sql("""
ALTER TABLE story_emotions add emotion_id bigint
""")
DB.sql("""
UPDATE story_emotions
SET emotion_id = emotions.id
FROM emotions
WHERE story_emotions.label = emotions.label
""")
DB.sql("""
ALTER TABLE story_emotions drop column label
""")
# leftover sanity-check queries; the results are discarded unless .show() is called
DB.sql("""
select
*
from emotions
""")
DB.sql("""
select
* from story_emotions
limit 4
""")
DB.close()
# export the raw predictions as a pipe-delimited CSV for inspection
out.to_csv(data_dir() / 'emotions.csv', sep="|")
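For reference, here is a standalone sketch of the chunked inference that debug() performs, assuming torch, pandas, tqdm, and transformers are importable at module level. Unlike debug(), it passes top_k=None so every emotion's score is kept (not just the winning label) and truncates over-long titles; both choices are assumptions rather than what the commit does.

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import pipeline


def classify_titles(table, n_chunks=5000):
    # hypothetical helper, not part of the commit
    classifier = pipeline(
        "text-classification",
        model="j-hartmann/emotion-english-distilroberta-base",
        top_k=None,  # return a score for every emotion instead of only the top label
    )
    rows = []
    for chunk in tqdm(np.array_split(table, n_chunks)):
        with torch.no_grad():
            results = classifier(chunk['title'].tolist(), truncation=True)
        for story_id, scores in zip(chunk['id'].tolist(), results):
            for item in scores:  # one {'label': ..., 'score': ...} dict per emotion
                rows.append({'story_id': story_id, **item})
    return pd.DataFrame(rows)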

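The CREATE/ALTER/UPDATE sequence that follows the inference can also be exercised on its own against an in-memory DuckDB database. A self-contained sketch with a toy DataFrame, using the same statements as debug():

import duckdb
import pandas as pd

out = pd.DataFrame({
    "story_id": [1, 2, 3],
    "label": ["joy", "anger", "joy"],
    "score": [0.91, 0.77, 0.85],
})

db = duckdb.connect()  # in-memory database
db.sql("CREATE OR REPLACE TABLE story_emotions AS SELECT story_id, label, score FROM out")
db.sql("""
    CREATE OR REPLACE TABLE emotions AS
    SELECT row_number() OVER () AS id, label, count(1) AS stories
    FROM story_emotions
    GROUP BY label
""")
db.sql("ALTER TABLE story_emotions ADD COLUMN emotion_id BIGINT")
db.sql("""
    UPDATE story_emotions
    SET emotion_id = emotions.id
    FROM emotions
    WHERE story_emotions.label = emotions.label
""")
db.sql("ALTER TABLE story_emotions DROP COLUMN label")
db.sql("SELECT * FROM story_emotions ORDER BY story_id").show()
db.close()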
View File

@@ -81,3 +81,21 @@ def distance():
min_index = np.argmin(distances)
closest = np.unravel_index(min_index, distances.shape)
distances.flatten().shape
# path = data_dir() / 'embeddings'
# chunks = [x for x in path.iterdir() if x.match('*.npy')]
# chunks = sorted(chunks, key=lambda x: int(x.stem.split('_')[1]))
#
# data = None
# for i, f in enumerate(tqdm(chunks)):
# loaded = np.load(f)
# if data is None:
# data = loaded
# else:
# data = np.concatenate([data, loaded])
# if i > 20:
# break
#
# data.shape
#
# np.save(data_dir() / 'embeddings.npy', data)
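The commented-out block above loads embedding chunks from disk and concatenates them. A working version is sketched below, assuming the embeddings/ directory holds .npy files named like "<prefix>_<index>.npy" (inferred from the sort key in the comments, not confirmed elsewhere) and that np.save takes the file path first.

import numpy as np
from tqdm import tqdm

from data import data_dir


def merge_embedding_chunks(limit=None):
    # hypothetical helper mirroring the commented-out sketch
    path = data_dir() / "embeddings"
    chunks = sorted(
        (x for x in path.iterdir() if x.match("*.npy")),
        key=lambda x: int(x.stem.split("_")[1]),
    )
    if limit is not None:  # the original sketch stopped after ~20 chunks
        chunks = chunks[:limit]
    data = np.concatenate([np.load(f) for f in tqdm(chunks)])
    np.save(data_dir() / "embeddings.npy", data)  # np.save expects (file, arr)
    return data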