Merge branch 'feature_emotion' of github.com:publicmatt/data_mining_577 into feature_emotion
This commit is contained in:
commit
4d93cf7adb
|
@ -30,4 +30,6 @@ if __name__ == "__main__":
|
||||||
cli.add_command(emotion.normalize)
|
cli.add_command(emotion.normalize)
|
||||||
cli.add_command(emotion.analyze)
|
cli.add_command(emotion.analyze)
|
||||||
cli.add_command(emotion.create_table)
|
cli.add_command(emotion.create_table)
|
||||||
|
import sentence
|
||||||
|
cli.add_command(sentence.embed)
|
||||||
cli()
|
cli()
|
||||||
|
|
|
@ -6,13 +6,14 @@ import numpy as np
|
||||||
|
|
||||||
from transformers import BertTokenizer
|
from transformers import BertTokenizer
|
||||||
from model import BertForMultiLabelClassification
|
from model import BertForMultiLabelClassification
|
||||||
from data import connect
|
from data import connect, data_dir
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.dates import DateFormatter
|
from matplotlib.dates import DateFormatter
|
||||||
import matplotlib.dates as mdates
|
import matplotlib.dates as mdates
|
||||||
|
|
||||||
def data():
|
def data():
|
||||||
|
|
||||||
# load data
|
# load data
|
||||||
DB = connect()
|
DB = connect()
|
||||||
table = DB.sql("""
|
table = DB.sql("""
|
||||||
|
@ -28,6 +29,7 @@ def data():
|
||||||
ORDER BY id DESC
|
ORDER BY id DESC
|
||||||
""").df()
|
""").df()
|
||||||
DB.close()
|
DB.close()
|
||||||
|
|
||||||
return table
|
return table
|
||||||
|
|
||||||
@click.command("emotion:create-table")
|
@click.command("emotion:create-table")
|
||||||
|
@ -298,3 +300,79 @@ def analyze():
|
||||||
|
|
||||||
sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
|
sns.lineplot(x=df['year'], y=df['frac'], hue=df['label'])
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
def debug():
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
# load data
|
||||||
|
DB = connect()
|
||||||
|
table = DB.sql("""
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
title
|
||||||
|
FROM stories
|
||||||
|
ORDER BY id DESC
|
||||||
|
""").df()
|
||||||
|
DB.close()
|
||||||
|
|
||||||
|
classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base")
|
||||||
|
|
||||||
|
chunks = 5000
|
||||||
|
chunked = np.array_split(table, chunks)
|
||||||
|
labels = []
|
||||||
|
ids = []
|
||||||
|
for chunk in tqdm(chunked):
|
||||||
|
sentences = chunk['title'].tolist()
|
||||||
|
label_ids = chunk['id'].tolist()
|
||||||
|
with torch.no_grad():
|
||||||
|
emotions = classifier(sentences)
|
||||||
|
labels.append(emotions)
|
||||||
|
ids.append(label_ids)
|
||||||
|
out = pd.DataFrame(np.concatenate(labels).tolist())
|
||||||
|
out_ids = pd.DataFrame(np.concatenate(ids).tolist(), columns=['story_id'])
|
||||||
|
out = pd.concat([out_ids, out], axis=1)
|
||||||
|
|
||||||
|
DB = connect()
|
||||||
|
DB.sql("""
|
||||||
|
CREATE OR REPLACE TABLE story_emotions AS
|
||||||
|
SELECT
|
||||||
|
story_id
|
||||||
|
,label
|
||||||
|
,score
|
||||||
|
FROM out
|
||||||
|
""")
|
||||||
|
DB.sql("""
|
||||||
|
CREATE OR REPLACE TABLE emotions AS
|
||||||
|
SELECT
|
||||||
|
row_number() over() as id
|
||||||
|
,label
|
||||||
|
,count(1) as stories
|
||||||
|
FROM story_emotions
|
||||||
|
GROUP BY
|
||||||
|
label
|
||||||
|
""")
|
||||||
|
DB.sql("""
|
||||||
|
ALTER TABLE story_emotions add emotion_id bigint
|
||||||
|
""")
|
||||||
|
DB.sql("""
|
||||||
|
UPDATE story_emotions
|
||||||
|
SET emotion_id = emotions.id
|
||||||
|
FROM emotions
|
||||||
|
WHERE story_emotions.label = emotions.label
|
||||||
|
""")
|
||||||
|
DB.sql("""
|
||||||
|
ALTER TABLE story_emotions drop column label
|
||||||
|
""")
|
||||||
|
DB.sql("""
|
||||||
|
select
|
||||||
|
*
|
||||||
|
from emotions
|
||||||
|
""")
|
||||||
|
DB.sql("""
|
||||||
|
select
|
||||||
|
* from story_emotions
|
||||||
|
limit 4
|
||||||
|
""")
|
||||||
|
DB.close()
|
||||||
|
|
||||||
|
out.to_csv(data_dir() / 'emotions.csv', sep="|")
|
||||||
|
|
19
src/word.py
19
src/word.py
|
@ -81,4 +81,21 @@ def distance():
|
||||||
min_index = (np.argmin(distances))
|
min_index = (np.argmin(distances))
|
||||||
closest = np.unravel_index(min_index, distances.shape)
|
closest = np.unravel_index(min_index, distances.shape)
|
||||||
distances.flatten().shape
|
distances.flatten().shape
|
||||||
DB.close()
|
|
||||||
|
# path = data_dir() / 'embeddings'
|
||||||
|
# chunks = [x for x in path.iterdir() if x.match('*.npy')]
|
||||||
|
# chunks = sorted(chunks, key=lambda x: int(x.stem.split('_')[1]))
|
||||||
|
#
|
||||||
|
# data = None
|
||||||
|
# for i, f in enumerate(tqdm(chunks)):
|
||||||
|
# loaded = np.load(f)
|
||||||
|
# if data is None:
|
||||||
|
# data = loaded
|
||||||
|
# else:
|
||||||
|
# data = np.concatenate([data, loaded])
|
||||||
|
# if i > 20:
|
||||||
|
# break
|
||||||
|
#
|
||||||
|
# data.shape
|
||||||
|
#
|
||||||
|
# np.save(data, data_dir() / 'embeddings.npy')
|
||||||
|
|
Loading…
Reference in New Issue