248 lines
6.5 KiB
Python
248 lines
6.5 KiB
Python
import click
|
|
from tqdm import tqdm
|
|
import torch
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
from transformers import BertTokenizer
|
|
from model import BertForMultiLabelClassification
|
|
from data import connect, data_dir
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
|
|
def data():
    """Fetch the stories that have not been labelled yet.

    Selects ``id`` and ``title`` from ``stories`` for every story whose id
    does not appear as a ``story_id`` in ``story_emotions``, newest id first.

    Returns:
        The query result converted with ``.df()`` (presumably a pandas
        DataFrame; ``extract`` calls ``.to_numpy()`` on it).
    """
    unlabelled_sql = """
        SELECT
            id,
            title
        FROM stories
        WHERE id NOT IN (
            SELECT
                DISTINCT story_id
            FROM story_emotions
        )
        ORDER BY id DESC
    """
    conn = connect()
    pending = conn.sql(unlabelled_sql).df()
    conn.close()

    return pending
|
|
|
|
@click.command("emotion:create-table")
def create_table():
    """create the table to hold the title id and labels."""
    # NOTE: the docstring above is also the click help text, so it is left
    # unchanged; implementation notes live in comments instead.
    DB = connect()
    table = "story_emotions"
    try:
        # The statement must be an f-string: without the prefix the literal
        # text "{table}" was sent to the database instead of the table name.
        DB.execute(f"""
            CREATE OR REPLACE TABLE {table}
            (
                story_id BIGINT,
                label TEXT,
                score REAL
            )
        """)
    finally:
        # Release the connection even when the statement fails.
        DB.close()
    print(f"\"{table}\" created")
|
|
|
|
@click.command("emotion:extract")
@click.option('-c', '--chunks', type=int, default=5000, show_default=True)
def extract(chunks):
    """extract emotion class labels from titles and put them in the db"""
    # NOTE: the docstring above is also the click help text, so it is left
    # unchanged; implementation notes live in comments instead.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizer.from_pretrained("monologg/bert-base-cased-goemotions-original")
    model = BertForMultiLabelClassification.from_pretrained("monologg/bert-base-cased-goemotions-original")
    model.to(device)

    table = data()
    # np.array_split produces exactly `chunks` parts, i.e. `chunks` is a part
    # count, not a batch size. With fewer rows than parts some parts are
    # empty; drop them so the tokenizer is never called with an empty batch.
    chunked = [
        part
        for part in np.array_split(table.to_numpy(), chunks)
        if len(part)
    ]
    for part in tqdm(chunked):
        # Each row is (id, title), matching the column order returned by data().
        ids = [x[0] for x in part]
        docs = [x[1] for x in part]
        tokens = tokenizer(docs, add_special_tokens = True, truncation = True, padding = "max_length", max_length=92, return_attention_mask = True, return_tensors = "pt")
        tokens = tokens.to(device)
        results = run(model, tokens, ids)
        if not results:
            # No label passed the threshold for this part; nothing to insert.
            continue
        # Explicit columns pin the order expected by `SELECT *` below.
        # The frame must stay bound to the local name `df`: the SQL refers to
        # it by that name (presumably a DuckDB replacement scan - confirm).
        df = pd.DataFrame(results, columns=["story_id", "label", "score"])
        DB = connect()
        try:
            DB.execute('INSERT INTO story_emotions SELECT * FROM df')
        finally:
            # Release the connection even when the insert fails.
            DB.close()
|
|
|
|
def run(model, tokens, ids, threshold=0.1):
    """Score one tokenized batch and keep the labels above ``threshold``.

    Args:
        model: Callable invoked as ``model(**tokens)``; the first element of
            its result is used as the logits and ``model.config.id2label``
            maps a column index to a label name.
        tokens: Mapping of keyword arguments forwarded to ``model``.
        ids: Story ids, indexed by the row position in the logits.
        threshold: Minimum sigmoid score (exclusive) for a label to be kept.
            Defaults to 0.1, the value that used to be hard-coded.

    Returns:
        A list of ``{"story_id", "label", "score"}`` dicts, ordered by row
        and then by label index.
    """
    with torch.no_grad():
        outputs = model(**tokens)[0].to('cpu').detach().numpy()

    scores = 1 / (1 + np.exp(-outputs))  # Sigmoid
    id2label = model.config.id2label
    # np.argwhere walks the array in row-major order, which is the same order
    # the previous nested enumerate loops produced.
    return [
        {"story_id": ids[row], "label": id2label[int(col)], "score": scores[row, col]}
        for row, col in np.argwhere(scores > threshold)
    ]
|
|
|
|
@click.command("emotion:normalize")
def normalize():
    """normalize the emotion tables."""
    # NOTE: the docstring above is also the click help text, so it is left
    # unchanged; implementation notes live in comments instead.
    #
    # Steps: build the `emotions` lookup table from the labels in
    # `story_emotions`, add an `emotion_id` column that points at it, then
    # drop the now redundant `label` column.
    DB = connect()
    try:
        # Only labels attached to more than 1000 stories get a lookup row.
        DB.sql("""
            CREATE OR REPLACE TABLE emotions AS
            SELECT
                row_number() over() as id
                ,e.label
                ,COUNT(1) AS stories
            FROM story_emotions e
            JOIN stories s
                ON s.id = e.story_id
            -- WHERE YEAR(s.published_at) < 2022
            GROUP BY e.label
            HAVING stories > 1000
            ORDER BY stories DESC
        """)
        DB.sql("""
            ALTER TABLE story_emotions
            ADD COLUMN emotion_id int64
        """)
        # Rows whose label has no entry in `emotions` keep a NULL emotion_id.
        DB.sql("""
            UPDATE story_emotions
            SET emotion_id = emotions.id
            FROM emotions
            WHERE emotions.label = story_emotions.label
        """)
        DB.sql("""
            ALTER TABLE story_emotions
            DROP COLUMN label
        """)

        # The previous closing query selected `e.label` from `story_emotions`,
        # a column that was dropped just above, and discarded its result.
        # The same summary is available from `emotions`, so read it from
        # there and display it (assumes connect() returns a DuckDB connection
        # whose relations expose .show() - confirm).
        DB.sql("""
            SELECT
                id
                ,label
                ,stories
            FROM emotions
            ORDER BY stories DESC
        """).show()
    finally:
        # Release the connection even when one of the statements fails.
        DB.close()
|
|
|
|
@click.command("emotion:analyze")
def analyze():
    """plot and group emotional labels"""
    # NOTE: the docstring above is also the click help text, so it is left
    # unchanged; implementation notes live in comments instead.
    DB = connect()
    try:
        # The result has to be materialised and bound to `df`; previously the
        # query result was dropped and the plot below raised NameError.
        #
        # NOTE(review): `total` joins `grouped` (one row per year) back to
        # `story_emotions`, so count(1) looks multiplied by the number of
        # years in `grouped`. The SQL is left as is - confirm the intended
        # denominator before changing it.
        df = DB.sql("""
            WITH grouped as (
                SELECT
                    YEAR(s.published_at) as year
                    ,e.label
                    ,COUNT(1) AS stories
                FROM story_emotions e
                JOIN stories s
                    ON s.id = e.story_id
                WHERE YEAR(s.published_at) < 2022
                    AND label = 'annoyance'
                GROUP BY
                    YEAR(s.published_at)
                    ,e.label
            ), total AS (
                SELECT
                    e.label
                    ,count(1) as total
                FROM grouped s
                JOIN story_emotions e
                    ON e.label = s.label
                GROUP BY
                    e.label
            )
            SELECT
                g.year
                ,g.label
                ,100 * (g.stories / CAST(t.total AS float)) AS frac
            FROM grouped g
            JOIN total t
                ON t.label = g.label
            ORDER BY g.label, g.year
        """).df()
    finally:
        # Release the connection even when the query fails.
        DB.close()

    # One line per label: yearly share (in percent) against the year.
    sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
    plt.show()
|
|
|
|
def debug():
    """Rebuild the emotion tables with an alternative classifier.

    Classifies every story title with the
    ``j-hartmann/emotion-english-distilroberta-base`` pipeline, replaces
    ``story_emotions`` and ``emotions`` with the result, and writes the
    labelled rows to ``emotions.csv`` under ``data_dir()``.
    """
    from transformers import pipeline

    # load data
    DB = connect()
    try:
        table = DB.sql("""
            SELECT
                id,
                title
            FROM stories
            ORDER BY id DESC
        """).df()
    finally:
        DB.close()

    classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")

    chunks = 5000
    # Split row positions instead of the DataFrame itself: np.array_split on
    # a DataFrame goes through the deprecated DataFrame.swapaxes. The
    # partition is the same. Empty groups (fewer rows than `chunks`) are
    # dropped so the classifier never receives an empty batch.
    row_groups = [
        rows
        for rows in np.array_split(np.arange(len(table)), chunks)
        if len(rows)
    ]
    labels = []
    ids = []
    for rows in tqdm(row_groups):
        chunk = table.iloc[rows]
        sentences = chunk['title'].tolist()
        label_ids = chunk['id'].tolist()
        with torch.no_grad():
            emotions = classifier(sentences)
        # Flat accumulation replaces np.concatenate, which upcast the ids to
        # float whenever an empty list was among the pieces.
        labels.extend(emotions)
        ids.extend(label_ids)
    # The SQL below reads the `label` and `score` columns of this frame.
    out = pd.DataFrame(labels)
    out_ids = pd.DataFrame(ids, columns=['story_id'])
    # `out` must keep this local name: the CREATE TABLE statement refers to
    # it by name (presumably a DuckDB replacement scan - confirm).
    out = pd.concat([out_ids, out], axis=1)

    DB = connect()
    try:
        DB.sql("""
            CREATE OR REPLACE TABLE story_emotions AS
            SELECT
                story_id
                ,label
                ,score
            FROM out
        """)
        DB.sql("""
            CREATE OR REPLACE TABLE emotions AS
            SELECT
                row_number() over() as id
                ,label
                ,count(1) as stories
            FROM story_emotions
            GROUP BY
                label
        """)
        DB.sql("""
            ALTER TABLE story_emotions add emotion_id bigint
        """)
        DB.sql("""
            UPDATE story_emotions
            SET emotion_id = emotions.id
            FROM emotions
            WHERE story_emotions.label = emotions.label
        """)
        DB.sql("""
            ALTER TABLE story_emotions drop column label
        """)
        # The two inspection queries used to be built and then discarded;
        # .show() actually prints them (assumes a DuckDB relation - confirm).
        DB.sql("""
            select
                *
            from emotions
        """).show()
        DB.sql("""
            select
                * from story_emotions
            limit 4
        """).show()
    finally:
        # Release the connection even when one of the statements fails.
        DB.close()

    out.to_csv(data_dir() / 'emotions.csv', sep="|")
|