Add a function to extract emotional labels from titles.
parent c38a5455a8
commit 3a6f97b290
@@ -23,4 +23,9 @@ if __name__ == "__main__":
     cli.add_command(mine.embeddings)
     cli.add_command(mine.cluster)
     cli.add_command(mine.plot)
+    import emotion
+    cli.add_command(emotion.extract)
+    cli.add_command(emotion.normalize)
+    cli.add_command(emotion.analyze)
+    cli.add_command(emotion.create_table)
     cli()
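Not part of the diff: a minimal sketch of how the registered commands are typically wired up and invoked with click. The entry-point module name and the group definition below are assumptions for illustration, not taken from this commit.

# Hypothetical entry point (module and group names are illustrative only).
import click

import emotion

@click.group()
def cli():
    pass

cli.add_command(emotion.create_table)   # invoked as "emotion:create-table"
cli.add_command(emotion.extract)        # invoked as "emotion:extract --chunks 5000"

if __name__ == "__main__":
    cli()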
@@ -0,0 +1,169 @@
import click
from tqdm import tqdm
import torch
import pandas as pd
import numpy as np

from transformers import BertTokenizer
from model import BertForMultiLabelClassification
from data import connect
import seaborn as sns
import matplotlib.pyplot as plt


def data():
    # load the stories that do not yet have emotion labels
    DB = connect()
    table = DB.sql("""
        SELECT
            id,
            title
        FROM stories
        WHERE id NOT IN (
            SELECT DISTINCT story_id
            FROM story_emotions
        )
        ORDER BY id DESC
    """).df()
    DB.close()
    return table
@click.command("emotion:create-table")
|
||||
def create_table():
|
||||
"""create the table to hold the title id and labels."""
|
||||
DB = connect()
|
||||
table = "story_emotions"
|
||||
DB.execute("""
|
||||
CREATE OR REPLACE TABLE {table}
|
||||
(
|
||||
story_id BIGINT,
|
||||
label TEXT,
|
||||
score REAL
|
||||
)
|
||||
""")
|
||||
DB.close()
|
||||
print(f"\"{table}\" created")
|
||||
|
||||
@click.command("emotion:extract")
|
||||
@click.option('-c', '--chunks', type=int, default=5000, show_default=True)
|
||||
def extract(chunks):
|
||||
"""extract emotion class labels from titles and put them in the db"""
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
tokenizer = BertTokenizer.from_pretrained("monologg/bert-base-cased-goemotions-original")
|
||||
model = BertForMultiLabelClassification.from_pretrained("monologg/bert-base-cased-goemotions-original")
|
||||
model.to(device)
|
||||
|
||||
table = data()
|
||||
chunked = np.array_split(table.to_numpy(), chunks)
|
||||
for part in tqdm(chunked):
|
||||
ids = [x[0] for x in part]
|
||||
docs = [x[1] for x in part]
|
||||
tokens = tokenizer(docs, add_special_tokens = True, truncation = True, padding = "max_length", max_length=92, return_attention_mask = True, return_tensors = "pt")
|
||||
tokens = tokens.to(device)
|
||||
results = run(model, tokens, ids)
|
||||
df = pd.DataFrame(results)
|
||||
DB = connect()
|
||||
DB.execute('INSERT INTO story_emotions SELECT * FROM df')
|
||||
DB.close()
|
||||
|
||||
def run(model, tokens, ids):
|
||||
threshold = 0.1
|
||||
with torch.no_grad():
|
||||
outputs = model(**tokens)[0].to('cpu').detach().numpy()
|
||||
scores = 1 / (1 + np.exp(-outputs)) # Sigmoid
|
||||
results = []
|
||||
for i, item in enumerate(scores):
|
||||
for idx, s in enumerate(item):
|
||||
if s > threshold:
|
||||
results.append({"story_id": ids[i], "label" : model.config.id2label[idx], "score": s})
|
||||
return results
|
||||
|
||||
@click.command("emotion:normalize")
|
||||
def normalize():
|
||||
"""normalize the emotion tables."""
|
||||
DB = connect()
|
||||
DB.sql("""
|
||||
CREATE OR REPLACE TABLE emotions AS
|
||||
SELECT
|
||||
row_number() over() as id
|
||||
,e.label
|
||||
,COUNT(1) AS stories
|
||||
FROM story_emotions e
|
||||
JOIN stories s
|
||||
ON s.id = e.story_id
|
||||
-- WHERE YEAR(s.published_at) < 2022
|
||||
GROUP BY e.label
|
||||
HAVING stories > 1000
|
||||
ORDER BY stories DESC
|
||||
""")
|
||||
DB.sql("""
|
||||
ALTER TABLE story_emotions
|
||||
ADD COLUMN emotion_id int64
|
||||
""")
|
||||
DB.sql("""
|
||||
UPDATE story_emotions
|
||||
SET emotion_id = emotions.id
|
||||
FROM emotions
|
||||
WHERE emotions.label = story_emotions.label
|
||||
""")
|
||||
DB.sql("""
|
||||
ALTER TABLE story_emotions
|
||||
DROP COLUMN label
|
||||
""")
|
||||
|
||||
DB.sql("""
|
||||
SELECT
|
||||
row_number() over() as id
|
||||
,e.label
|
||||
,COUNT(1) AS stories
|
||||
FROM story_emotions e
|
||||
JOIN stories s
|
||||
ON s.id = e.story_id
|
||||
-- WHERE YEAR(s.published_at) < 2022
|
||||
GROUP BY e.label
|
||||
HAVING stories > 1000
|
||||
ORDER BY stories DESC
|
||||
""")
|
||||
DB.close()
|
||||
|
||||
@click.command("emotion:analyze")
|
||||
def analyze():
|
||||
"""plot and group emotional labels"""
|
||||
DB = connect()
|
||||
DB.sql("""
|
||||
WITH grouped as (
|
||||
SELECT
|
||||
YEAR(s.published_at) as year
|
||||
,e.label
|
||||
,COUNT(1) AS stories
|
||||
FROM story_emotions e
|
||||
JOIN stories s
|
||||
ON s.id = e.story_id
|
||||
WHERE YEAR(s.published_at) < 2022
|
||||
AND label = 'annoyance'
|
||||
GROUP BY
|
||||
YEAR(s.published_at)
|
||||
,e.label
|
||||
), total AS (
|
||||
SELECT
|
||||
e.label
|
||||
,count(1) as total
|
||||
FROM grouped s
|
||||
JOIN story_emotions e
|
||||
ON e.label = s.label
|
||||
GROUP BY
|
||||
e.label
|
||||
)
|
||||
SELECT
|
||||
g.year
|
||||
,g.label
|
||||
,100 * (g.stories / CAST(t.total AS float)) AS frac
|
||||
FROM grouped g
|
||||
JOIN total t
|
||||
ON t.label = g.label
|
||||
ORDER BY g.label, g.year
|
||||
""")
|
||||
DB.close()
|
||||
|
||||
sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
|
||||
plt.show()
|
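Not part of the diff: a minimal single-title walk-through of what extract() and run() above do, using the same GoEmotions checkpoint, sigmoid, and 0.1 threshold; the example headline is invented.

# Illustrative only: one made-up title pushed through the same model and
# thresholding logic as run() above.
import numpy as np
import torch
from transformers import BertTokenizer

from model import BertForMultiLabelClassification

name = "monologg/bert-base-cased-goemotions-original"
tokenizer = BertTokenizer.from_pretrained(name)
model = BertForMultiLabelClassification.from_pretrained(name)
model.eval()

tokens = tokenizer(["Startup raises $10M to fix everything"],
                   truncation=True, padding="max_length",
                   max_length=92, return_tensors="pt")
with torch.no_grad():
    logits = model(**tokens)[0].numpy()
scores = 1 / (1 + np.exp(-logits))   # sigmoid: independent probability per label
for idx, score in enumerate(scores[0]):
    if score > 0.1:                  # same threshold as run()
        print(model.config.id2label[idx], round(float(score), 3))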
src/model.py

@@ -14,3 +14,50 @@ class Model(nn.Module):
         outs = self.bert(**x)
         outs = self.act(self.linear(outs.last_hidden_state))
         return outs
+
+import torch.nn as nn
+from transformers import BertPreTrainedModel, BertModel
+
+
+class BertForMultiLabelClassification(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.bert = BertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
+        # BCEWithLogitsLoss applies the sigmoid internally, so the targets are
+        # float multi-hot vectors and the forward pass returns raw logits
+        self.loss_fct = nn.BCEWithLogitsLoss()
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+    ):
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+        )
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attentions if they are here
+
+        if labels is not None:
+            loss = self.loss_fct(logits, labels)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
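Also not part of the diff: a hedged sketch of the forward() contract above when labels are supplied. Because the loss is nn.BCEWithLogitsLoss, the targets are float multi-hot vectors rather than a single class index; the headline and label index below are fabricated.

# Illustrative usage only; the input text and target vector are made up.
import torch
from transformers import BertTokenizer

from model import BertForMultiLabelClassification

name = "monologg/bert-base-cased-goemotions-original"
tokenizer = BertTokenizer.from_pretrained(name)
model = BertForMultiLabelClassification.from_pretrained(name)

batch = tokenizer(["a made-up headline"], return_tensors="pt")
labels = torch.zeros((1, model.config.num_labels))
labels[0, 0] = 1.0   # pretend the first label applies

loss, logits = model(**batch, labels=labels)[:2]
print(loss.item(), logits.shape)   # scalar loss, (1, num_labels) logits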
src/word.py

@@ -31,13 +31,14 @@ def train():
 
 @click.command(name="word:embed")
 @click.option('-c', '--chunks', type=int, default=5000, show_default=True)
-@click.option('--embedding_dest', help="path to save embeddings as np array", type=Path, default=Path(data_dir() / 'sequence_embeddings.npy'), show_default=True)
-@click.option('--token_dest', help="path to save tokens as np array", type=Path, default=Path(data_dir() / 'sequence_tokens.npy'), show_default=True)
-def embed(chunks, embedding_dest, token_dest):
+@click.option('--embedding_dir', help="path to save embeddings as np array", type=Path, default=Path(data_dir() / 'embeddings'), show_default=True)
+@click.option('--token_dir', help="path to save tokens as np array", type=Path, default=Path(data_dir() / 'tokens'), show_default=True)
+@click.option('--device', help="device to process data on", type=str, default="cuda:0", show_default=True)
+def embed(chunks, embedding_dir, token_dir, device):
     """ given titles, generate tokens and word embeddings and saves to disk """
 
     # init models
-    device = torch.device('cuda:0')
+    device = torch.device(device)
     tokenizer = AutoTokenizer.from_pretrained("roberta-base")
     model = RobertaModel.from_pretrained("roberta-base")
     model.to(device)

@@ -56,29 +57,21 @@ def embed(chunks, embedding_dest, token_dest):
     table['title'] = table['title'].str.normalize('NFKD').str.encode('ascii', errors='ignore').str.decode('utf-8')
 
-    # generate embeddings from list of titles
-    def get_embeddings(titles):
+    chunks = np.array_split(table['title'].to_numpy(), chunks)
+    chunk_iter = tqdm(chunks, 'embedding')
+    for i, chunk in enumerate(chunk_iter):
         # create tokens, padding to max width
-        tokens = tokenizer(titles, add_special_tokens = True, truncation = True, padding = "max_length", max_length=92, return_attention_mask = True, return_tensors = "pt")
+        tokens = tokenizer(chunk.tolist(), add_special_tokens = True, truncation = True, padding = "max_length", max_length=92, return_attention_mask = True, return_tensors = "pt")
         tokens = tokens.to(device)
         with torch.no_grad():
             outputs = model(**tokens)
-        return tokens.to(torch.device('cpu')), outputs.last_hidden_state.to(torch.device('cpu'))
-
-    tokens = []
-    embeddings = []
-    chunks = np.array_split(table['title'].to_numpy(), chunks)
-    chunk_iter = tqdm(chunks, 'embedding')
-    for chunk in chunk_iter:
-        data = chunk.tolist()
-        token, embedding = get_embeddings(data)
-        arr = embedding.detach().numpy()
-        embeddings.append(arr)
-        tokens.append(token)
+        # to disk
+        hidden = outputs.last_hidden_state.to(torch.device('cpu')).detach().numpy()
+        np.save(embedding_dir / f"embedding_{i}.npy", hidden)
 
-    embeddings = np.concatenate(embeddings)
-    tokens = np.concatenate(tokens)
-    np.save(embedding_dest, embeddings)
-    np.save(token_dest, tokens)
+        tokens = tokens.to(torch.device('cpu'))
+        np.save(token_dir / f"token_{i}.npy", tokens)
 
 
 @click.command(name="word:distance")
 def distance():
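Not part of the diff: a sketch of how the per-chunk .npy files written by the new loop might be read back and stitched together; the directory path and the consuming code are assumptions (the command's real default is data_dir() / 'embeddings').

# Hypothetical consumer of the per-chunk embedding files saved above
# (the path below is an assumption, not the project's actual default).
from pathlib import Path

import numpy as np

embedding_dir = Path("data/embeddings")
parts = sorted(embedding_dir.glob("embedding_*.npy"),
               key=lambda p: int(p.stem.split("_")[1]))   # numeric chunk order
embeddings = np.concatenate([np.load(p) for p in parts])
print(embeddings.shape)   # (n_titles, 92, 768) for roberta-base at max_length=92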