Merge branch 'feature_emotion' of github.com:publicmatt/data_mining_577 into feature_emotion

This commit is contained in:
matt 2023-05-08 14:20:00 -07:00
commit 4d93cf7adb
3 changed files with 99 additions and 2 deletions

View File

@ -30,4 +30,6 @@ if __name__ == "__main__":
    cli.add_command(emotion.normalize)
    cli.add_command(emotion.analyze)
    cli.add_command(emotion.create_table)
    import sentence
    cli.add_command(sentence.embed)
    cli()

View File

@ -6,13 +6,14 @@ import numpy as np
from transformers import BertTokenizer
from model import BertForMultiLabelClassification
from data import connect, data_dir
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
import matplotlib.dates as mdates
def data():
    # load data
    DB = connect()
    table = DB.sql("""
@ -28,6 +29,7 @@ def data():
    ORDER BY id DESC
    """).df()
    DB.close()
    return table
@click.command("emotion:create-table")
@ -298,3 +300,79 @@ def analyze():
    sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
    plt.show()
def debug():
    """One-off job: label every story title with an emotion and persist it.

    Pulls all story titles out of DuckDB, runs a pretrained emotion
    classifier over them in chunks, rebuilds the `story_emotions` and
    `emotions` tables from the predictions, and writes a CSV copy of the
    raw predictions.

    NOTE(review): assumes `tqdm`, `torch`, `pd` (pandas) and `np` (numpy)
    are imported at module level — not visible in this diff view; confirm.
    """
    # Local import: transformers is heavy and only needed on this path.
    from transformers import pipeline
    # load data
    DB = connect()
    table = DB.sql("""
    SELECT
        id,
        title
    FROM stories
    ORDER BY id DESC
    """).df()
    DB.close()
    # Pretrained DistilRoBERTa fine-tuned for emotion classification.
    # Presumably yields one {label, score} dict per input title — verify
    # against the transformers pipeline docs before relying on the shape.
    classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
    chunks = 5000
    # Split the table into ~5000 roughly equal DataFrame chunks so the
    # pipeline is fed batches rather than the whole corpus at once.
    chunked = np.array_split(table, chunks)
    labels = []  # per-chunk lists of prediction dicts
    ids = []     # per-chunk lists of story ids, kept aligned with `labels`
    for chunk in tqdm(chunked):
        sentences = chunk['title'].tolist()
        label_ids = chunk['id'].tolist()
        # Inference only — disable autograd bookkeeping.
        with torch.no_grad():
            emotions = classifier(sentences)
        labels.append(emotions)
        ids.append(label_ids)
    # Flatten the per-chunk results into one row per story, ids first.
    out = pd.DataFrame(np.concatenate(labels).tolist())
    out_ids = pd.DataFrame(np.concatenate(ids).tolist(), columns=['story_id'])
    out = pd.concat([out_ids, out], axis=1)
    DB = connect()
    # DuckDB resolves the bare `out` reference below by scanning the local
    # pandas DataFrame (replacement scan).
    DB.sql("""
    CREATE OR REPLACE TABLE story_emotions AS
    SELECT
        story_id
        ,label
        ,score
    FROM out
    """)
    # Dimension table: one row per distinct emotion label, with counts.
    DB.sql("""
    CREATE OR REPLACE TABLE emotions AS
    SELECT
        row_number() over() as id
        ,label
        ,count(1) as stories
    FROM story_emotions
    GROUP BY
        label
    """)
    # Normalize: swap the text label on story_emotions for a foreign key.
    DB.sql("""
    ALTER TABLE story_emotions add emotion_id bigint
    """)
    DB.sql("""
    UPDATE story_emotions
    SET emotion_id = emotions.id
    FROM emotions
    WHERE story_emotions.label = emotions.label
    """)
    DB.sql("""
    ALTER TABLE story_emotions drop column label
    """)
    # NOTE(review): the two SELECTs below discard their results — looks like
    # leftover interactive debugging; confirm before deleting.
    DB.sql("""
    select
        *
    from emotions
    """)
    DB.sql("""
    select
        * from story_emotions
    limit 4
    """)
    DB.close()
    # Pipe-separated dump of the raw predictions alongside the DB tables.
    out.to_csv(data_dir() / 'emotions.csv', sep="|")

View File

@ -81,4 +81,21 @@ def distance():
    min_index = (np.argmin(distances))
    closest = np.unravel_index(min_index, distances.shape)
    distances.flatten().shape
DB.close()
# path = data_dir() / 'embeddings'
# chunks = [x for x in path.iterdir() if x.match('*.npy')]
# chunks = sorted(chunks, key=lambda x: int(x.stem.split('_')[1]))
#
# data = None
# for i, f in enumerate(tqdm(chunks)):
# loaded = np.load(f)
# if data is None:
# data = loaded
# else:
# data = np.concatenate([data, loaded])
# if i > 20:
# break
#
# data.shape
#
# np.save(data, data_dir() / 'embeddings.npy')