This commit is contained in:
matt 2023-04-11 13:27:56 -07:00
commit b9c63414a0
8 changed files with 160 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*.csv
*.swp

BIN
docs/title_len.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

65
src/cli.py Normal file
View File

@ -0,0 +1,65 @@
import polars as pl
import duckdb
import toml
import os
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
from enum import Enum, auto
import click
DATA_DIR = Path(os.environ['DATA_MINING_DATA_DIR'])
APP_DIR = Path(os.environ['DATA_MINING_APP_DIR'])
db = duckdb.connect(str(DATA_DIR / 'project.duckdb'))
@click.group()
def cli():
...
class PlotName(str, Enum):
TitleLength = "title_len"
OutletStories = "outlet_stories"
@cli.command()
@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
@click.option('-o', '--output', required=False, type=click.Path())
def plot(name: PlotName, output: Path):
output = output if output else APP_DIR / f'docs/{name}.png'
if name == PlotName.TitleLength:
fig, ax = plt.subplots(1,1)
data = db.sql("""
select
length(title) as len
from stories
""").df()
sns.histplot(x=data['len'], bins=50, ax=ax[0])
ax[0].set(ylabel="count", xlabel="title length")
elif name == PlotName.OutletStories:
data = db.sql("""
with cte as (
select
count(1) as stories
from stories
group by outlet
)
select
row_number() over(order by stories desc) as id
,log(stories) as log_count
from cte
""").df()
fig, ax = plt.subplots(1,1)
sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
from matplotlib.ticker import ScalarFormatter
ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)", majorformater=ScalarFormatter)
plt.show()
else:
raise NotImplementedError("option unrecognized")
plt.savefig(output)
if __name__ == "__main__":
cli()

21
src/data.py Normal file
View File

@ -0,0 +1,21 @@
import os
from pathlib import Path
import duckdb
from enum import Enum
class Data(str, Enum):
Titles = 'titles'
def from_db(t: Data):
DATA_DIR = Path(os.environ['DATA_MINING_DATA_DIR'])
# APP_DIR = Path(os.environ['DATA_MINING_APP_DIR'])
DB = duckdb.connect(str(DATA_DIR / 'project.duckdb'))
if t == Data.Titles:
table = DB.sql("""
select
distinct
title
from stories
limit 100
""").df()
return table

8
src/lib.py Normal file
View File

@ -0,0 +1,8 @@
import sklearn
import polars as pl
import toml
from pathlib import Path
config = toml.load('/home/user/577/repo/config.toml')
app_dir = Path(config.get('app').get('path'))
df = pl.read_csv(app_dir / "data/articles.csv")

16
src/model.py Normal file
View File

@ -0,0 +1,16 @@
from transformers import AutoTokenizer, RobertaModel
import torch
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.n_classes = 10
self.bert = RobertaModel.from_pretrained("roberta-base")
self.linear = torch.nn.Linear(self.bert.config.hidden_size, self.n_classes)
self.act = torch.nn.Sigmoid()
def forward(self, x):
outs = self.bert(**x)
outs = self.act(self.linear(outs.last_hidden_state))
return outs

5
src/nearest_neighbor.py Normal file
View File

@ -0,0 +1,5 @@
import pandas as pd
import math
df = pd.read_csv('/tmp/attr.csv')
((((df.left - 9.1) ** 2) + ((df.right - 11.0) ** 2)) ** 0.5).sort_values()

43
src/word.py Normal file
View File

@ -0,0 +1,43 @@
import click
from scipy.spatial import distance
from transformers import AutoTokenizer, RobertaModel
import numpy as np
from model import Model
from data import Data, from_db
@click.group()
def cli():
...
@cli.command()
def train():
table = from_db(Data.Titles)
n_classes = 10
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# create tokens, padding to max width
tokens = tokenizer(table['title'].apply(str).to_list(), add_special_tokens = True, truncation = True, padding = "max_length", return_attention_mask = True, return_tensors = "pt")
pred_y = outputs[:, 0, :]
model = RobertaModel.from_pretrained("roberta-base")
pred_y = model(**inputs)
outputs = model(**tokens)
# linear = torch.nn.Linear(model.config.hidden_size, n_classes)
# act = torch.nn.Sigmoid()
# model = Model()
pred_y.last_hidden_state[:, 0, :].shape
classes = act(linear(pred_y.last_hidden_state[:, 0, :])).detach()
@cli.command()
def distance():
distances = distance.cdist(classes, classes, 'euclidean')
np.fill_diagonal(distances, np.inf)
min_index = (np.argmin(distances))
closest = np.unravel_index(min_index, distances.shape)
distances.flatten().shape
if __name__ == "__main__":
cli()