add another emotion pretrained model.
This commit is contained in:
parent
3a6f97b290
commit
f3d7678066
|
@ -28,4 +28,6 @@ if __name__ == "__main__":
|
|||
cli.add_command(emotion.normalize)
|
||||
cli.add_command(emotion.analyze)
|
||||
cli.add_command(emotion.create_table)
|
||||
import sentence
|
||||
cli.add_command(sentence.embed)
|
||||
cli()
|
||||
|
|
|
@ -6,11 +6,12 @@ import numpy as np
|
|||
|
||||
from transformers import BertTokenizer
|
||||
from model import BertForMultiLabelClassification
|
||||
from data import connect
|
||||
from data import connect, data_dir
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def data():
|
||||
|
||||
# load data
|
||||
DB = connect()
|
||||
table = DB.sql("""
|
||||
|
@ -26,6 +27,7 @@ def data():
|
|||
ORDER BY id DESC
|
||||
""").df()
|
||||
DB.close()
|
||||
|
||||
return table
|
||||
|
||||
@click.command("emotion:create-table")
|
||||
|
@ -167,3 +169,79 @@ def analyze():
|
|||
|
||||
sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
|
||||
plt.show()
|
||||
|
||||
def debug():
|
||||
from transformers import pipeline
|
||||
|
||||
# load data
|
||||
DB = connect()
|
||||
table = DB.sql("""
|
||||
SELECT
|
||||
id,
|
||||
title
|
||||
FROM stories
|
||||
ORDER BY id DESC
|
||||
""").df()
|
||||
DB.close()
|
||||
|
||||
classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
|
||||
|
||||
chunks = 5000
|
||||
chunked = np.array_split(table, chunks)
|
||||
labels = []
|
||||
ids = []
|
||||
for chunk in tqdm(chunked):
|
||||
sentences = chunk['title'].tolist()
|
||||
label_ids = chunk['id'].tolist()
|
||||
with torch.no_grad():
|
||||
emotions = classifier(sentences)
|
||||
labels.append(emotions)
|
||||
ids.append(label_ids)
|
||||
out = pd.DataFrame(np.concatenate(labels).tolist())
|
||||
out_ids = pd.DataFrame(np.concatenate(ids).tolist(), columns=['story_id'])
|
||||
out = pd.concat([out_ids, out], axis=1)
|
||||
|
||||
DB = connect()
|
||||
DB.sql("""
|
||||
CREATE OR REPLACE TABLE story_emotions AS
|
||||
SELECT
|
||||
story_id
|
||||
,label
|
||||
,score
|
||||
FROM out
|
||||
""")
|
||||
DB.sql("""
|
||||
CREATE OR REPLACE TABLE emotions AS
|
||||
SELECT
|
||||
row_number() over() as id
|
||||
,label
|
||||
,count(1) as stories
|
||||
FROM story_emotions
|
||||
GROUP BY
|
||||
label
|
||||
""")
|
||||
DB.sql("""
|
||||
ALTER TABLE story_emotions add emotion_id bigint
|
||||
""")
|
||||
DB.sql("""
|
||||
UPDATE story_emotions
|
||||
SET emotion_id = emotions.id
|
||||
FROM emotions
|
||||
WHERE story_emotions.label = emotions.label
|
||||
""")
|
||||
DB.sql("""
|
||||
ALTER TABLE story_emotions drop column label
|
||||
""")
|
||||
DB.sql("""
|
||||
select
|
||||
*
|
||||
from emotions
|
||||
""")
|
||||
DB.sql("""
|
||||
select
|
||||
* from story_emotions
|
||||
limit 4
|
||||
""")
|
||||
DB.close()
|
||||
|
||||
out.to_csv(data_dir() / 'emotions.csv', sep="|")
|
||||
|
|
18
src/word.py
18
src/word.py
|
@ -81,3 +81,21 @@ def distance():
|
|||
min_index = (np.argmin(distances))
|
||||
closest = np.unravel_index(min_index, distances.shape)
|
||||
distances.flatten().shape
|
||||
|
||||
# path = data_dir() / 'embeddings'
|
||||
# chunks = [x for x in path.iterdir() if x.match('*.npy')]
|
||||
# chunks = sorted(chunks, key=lambda x: int(x.stem.split('_')[1]))
|
||||
#
|
||||
# data = None
|
||||
# for i, f in enumerate(tqdm(chunks)):
|
||||
# loaded = np.load(f)
|
||||
# if data is None:
|
||||
# data = loaded
|
||||
# else:
|
||||
# data = np.concatenate([data, loaded])
|
||||
# if i > 20:
|
||||
# break
|
||||
#
|
||||
# data.shape
|
||||
#
|
||||
# np.save(data, data_dir() / 'embeddings.npy')
|
||||
|
|
Loading…
Reference in New Issue