wwu-577/src/emotion.py

248 lines
6.5 KiB
Python

import click
from tqdm import tqdm
import torch
import pandas as pd
import numpy as np
from transformers import BertTokenizer
from model import BertForMultiLabelClassification
from data import connect, data_dir
import seaborn as sns
import matplotlib.pyplot as plt
def data():
    """Return the stories that do not yet have any rows in ``story_emotions``.

    Returns:
        A DataFrame with columns ``id`` and ``title``, ordered by ``id``
        descending.
    """
    DB = connect()
    try:
        # NOT EXISTS instead of NOT IN: a single NULL story_id in the subquery
        # would make `id NOT IN (...)` unknown for every row and return nothing.
        return DB.sql("""
            SELECT
                s.id,
                s.title
            FROM stories s
            WHERE NOT EXISTS (
                SELECT 1
                FROM story_emotions e
                WHERE e.story_id = s.id
            )
            ORDER BY s.id DESC
        """).df()
    finally:
        DB.close()
@click.command("emotion:create-table")
def create_table():
    """Create (or replace) the empty story_emotions table of (story_id, label, score) rows.

    CREATE OR REPLACE discards any rows already in the table.
    """
    table = "story_emotions"
    DB = connect()
    try:
        # Must be an f-string: without it DuckDB receives the literal text
        # "{table}" and the statement fails. `table` is a constant here.
        DB.execute(f"""
            CREATE OR REPLACE TABLE {table}
            (
                story_id BIGINT,
                label TEXT,
                score REAL
            )
        """)
    finally:
        DB.close()
    print(f"\"{table}\" created")
@click.command("emotion:extract")
@click.option('-c', '--chunks', type=int, default=5000, show_default=True,
              help="Number of batches to split the unlabeled titles into.")
def extract(chunks):
    """Extract emotion class labels for unlabeled story titles and put them in the db.

    Titles come from ``data()`` (stories with no story_emotions rows yet). Each
    batch is tokenized, scored by ``run``, and the (story_id, label, score) rows
    above its threshold are inserted into story_emotions.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizer.from_pretrained("monologg/bert-base-cased-goemotions-original")
    model = BertForMultiLabelClassification.from_pretrained("monologg/bert-base-cased-goemotions-original")
    model.to(device)
    model.eval()  # inference only: make sure dropout is disabled

    table = data()
    # `chunks` is the number of batches, not the batch size. With fewer rows
    # than batches, array_split produces empty batches; they are skipped.
    for part in tqdm(np.array_split(table.to_numpy(), chunks)):
        if len(part) == 0:
            continue
        ids = [row[0] for row in part]
        docs = [row[1] for row in part]
        tokens = tokenizer(
            docs,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            max_length=92,
            return_attention_mask=True,
            return_tensors="pt",
        ).to(device)
        results = run(model, tokens, ids)
        if not results:
            # An empty DataFrame has no columns, so the INSERT below would fail.
            continue
        df = pd.DataFrame(results)
        DB = connect()
        try:
            # The SQL reads the local DataFrame `df` by name.
            DB.execute('INSERT INTO story_emotions SELECT story_id, label, score FROM df')
        finally:
            DB.close()
def run(model, tokens, ids, threshold=0.1):
    """Score a tokenized batch and keep every (story, label) above ``threshold``.

    Args:
        model: Callable as ``model(**tokens)`` whose first output holds the
            logits, and exposing ``model.config.id2label``. Presumably shaped
            (batch, num_labels) — the indexing below assumes rank 2.
        tokens: Mapping of model inputs (e.g. a tokenizer BatchEncoding).
        ids: Story ids aligned with the rows of the batch.
        threshold: Minimum sigmoid score (exclusive) for a label to be kept.
            Defaults to 0.1.

    Returns:
        A list of ``{"story_id", "label", "score"}`` dicts in row-major order
        (by story, then by label index).
    """
    with torch.no_grad():
        # Independent per-label sigmoid (multi-label). torch.sigmoid avoids
        # the overflow warnings of 1 / (1 + np.exp(-x)) on large negative logits.
        scores = torch.sigmoid(model(**tokens)[0]).cpu().numpy()
    id2label = model.config.id2label
    rows, cols = np.nonzero(scores > threshold)  # row-major, like a nested loop
    return [
        {"story_id": ids[int(r)], "label": id2label[int(c)], "score": scores[r, c]}
        for r, c in zip(rows, cols)
    ]
@click.command("emotion:normalize")
def normalize():
    """Normalize the emotion tables: move labels into an emotions lookup table.

    Builds ``emotions(id, label, stories)`` from labels with more than 1000
    rows joined to ``stories``. Adds ``story_emotions.emotion_id`` pointing at
    ``emotions.id`` and drops the free-text ``story_emotions.label`` column.

    NOTE(review): story_emotions rows whose label did not make the 1000 cut
    keep a NULL emotion_id and lose their label when the column is dropped.
    The command is not re-runnable: a second run fails because ``label`` no
    longer exists.
    """
    DB = connect()
    try:
        # Ids are ranked by frequency (ties broken by label) so they are stable.
        # A bare `row_number() over()` is not guaranteed to follow ORDER BY.
        DB.sql("""
            CREATE OR REPLACE TABLE emotions AS
            SELECT
                row_number() OVER (ORDER BY COUNT(1) DESC, e.label) AS id
                ,e.label
                ,COUNT(1) AS stories
            FROM story_emotions e
            JOIN stories s
                ON s.id = e.story_id
            -- WHERE YEAR(s.published_at) < 2022
            GROUP BY e.label
            HAVING stories > 1000
            ORDER BY stories DESC
        """)
        DB.sql("""
            ALTER TABLE story_emotions
            ADD COLUMN emotion_id BIGINT
        """)
        DB.sql("""
            UPDATE story_emotions
            SET emotion_id = emotions.id
            FROM emotions
            WHERE emotions.label = story_emotions.label
        """)
        # The trailing SELECT that used to follow this referenced e.label after
        # this drop, so the command always ended in a binder error.
        DB.sql("""
            ALTER TABLE story_emotions
            DROP COLUMN label
        """)
    finally:
        DB.close()
@click.command("emotion:analyze")
def analyze():
    """Plot and group emotional labels by publication year.

    For each year before 2022, plots the percentage of a label's story_emotions
    rows that fall in that year. The query is currently restricted to the
    'annoyance' label. It reads story_emotions.label, so it needs the table as
    it is before ``emotion:normalize`` drops that column.
    """
    DB = connect()
    try:
        df = DB.sql("""
            WITH grouped AS (
                SELECT
                    YEAR(s.published_at) AS year
                    ,e.label
                    ,COUNT(1) AS stories
                FROM story_emotions e
                JOIN stories s
                    ON s.id = e.story_id
                WHERE YEAR(s.published_at) < 2022
                    AND e.label = 'annoyance'
                GROUP BY
                    YEAR(s.published_at)
                    ,e.label
            ), total AS (
                -- Count each label's rows once. Joining story_emotions to
                -- `grouped` (one row per year) would multiply the count by
                -- the number of years.
                SELECT
                    e.label
                    ,COUNT(1) AS total
                FROM story_emotions e
                WHERE e.label IN (SELECT label FROM grouped)
                GROUP BY
                    e.label
            )
            SELECT
                g.year
                ,g.label
                ,100 * (g.stories / CAST(t.total AS float)) AS frac
            FROM grouped g
            JOIN total t
                ON t.label = g.label
            ORDER BY g.label, g.year
        """).df()
    finally:
        DB.close()
    sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
    plt.show()
def debug():
    """Re-label every story title with a text-classification pipeline and rebuild the tables.

    Runs the ``j-hartmann/emotion-english-distilroberta-base`` pipeline over all
    titles in batches. It expects one {label, score} prediction per title.
    It then:
      * replaces story_emotions with (story_id, label, score);
      * rebuilds emotions(id, label, stories) from those labels;
      * swaps story_emotions.label for emotion_id;
      * writes the raw predictions to ``data_dir() / 'emotions.csv'``
        (pipe-separated, with the DataFrame index).
    """
    from transformers import pipeline

    # load data
    DB = connect()
    try:
        table = DB.sql("""
            SELECT
                id,
                title
            FROM stories
            ORDER BY id DESC
        """).df()
    finally:
        DB.close()

    classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")

    chunks = 5000
    predictions = []
    story_ids = []
    # Split row positions rather than the DataFrame itself: np.array_split on a
    # DataFrame goes through the deprecated DataFrame.swapaxes.
    for positions in tqdm(np.array_split(np.arange(len(table)), chunks)):
        if len(positions) == 0:
            continue  # fewer rows than chunks yields empty batches
        chunk = table.iloc[positions]
        with torch.no_grad():
            predictions.extend(classifier(chunk['title'].tolist()))
        story_ids.extend(chunk['id'].tolist())

    # Explicit columns so an empty run still has label/score for the SQL below.
    out = pd.DataFrame(predictions, columns=['label', 'score'])
    out.insert(0, 'story_id', story_ids)

    DB = connect()
    try:
        # The SQL reads the local DataFrame `out` by name.
        DB.sql("""
            CREATE OR REPLACE TABLE story_emotions AS
            SELECT
                story_id
                ,label
                ,score
            FROM out
        """)
        DB.sql("""
            CREATE OR REPLACE TABLE emotions AS
            SELECT
                row_number() over() as id
                ,label
                ,count(1) as stories
            FROM story_emotions
            GROUP BY
                label
        """)
        DB.sql("""
            ALTER TABLE story_emotions add emotion_id bigint
        """)
        DB.sql("""
            UPDATE story_emotions
            SET emotion_id = emotions.id
            FROM emotions
            WHERE story_emotions.label = emotions.label
        """)
        DB.sql("""
            ALTER TABLE story_emotions drop column label
        """)
    finally:
        DB.close()
    out.to_csv(data_dir() / 'emotions.csv', sep="|")