add loading csv data to database.
This commit is contained in:
parent
feb3a4b8ed
commit
297aeec32d
|
@ -4,7 +4,7 @@ import requests
|
|||
from pathlib import Path
|
||||
import click
|
||||
from tqdm import tqdm
|
||||
from data import data_dir
|
||||
from data import data_dir, connect
|
||||
from lxml import etree
|
||||
import pandas as pd
|
||||
|
||||
|
@ -12,6 +12,29 @@ import pandas as pd
|
|||
def cli():
|
||||
...
|
||||
|
||||
@cli.command()
@click.option('--directory', type=Path, default=data_dir())
@click.option('--database', type=Path, default=data_dir() / "stories.duckdb")
def load(directory, database):
    """Load the CSV exports found in *directory* into database tables.

    Creates two tables from two files inside ``directory``:

    * ``stories``          from ``stories.csv``
    * ``related_stories``  from ``related.csv``

    Args:
        directory: Folder containing ``stories.csv`` and ``related.csv``.
        database: Path supplied via ``--database``.
            NOTE(review): this value is accepted but not forwarded to
            ``connect()``; confirm whether ``connect()`` should receive it.
    """

    def sql_literal(path):
        # Double any single quote so a path such as "o'brien/stories.csv"
        # cannot terminate the SQL string literal early.
        return "'" + str(path).replace("'", "''") + "'"

    # Table name -> CSV file it is populated from (creation order preserved).
    sources = {
        "stories": directory / "stories.csv",
        "related_stories": directory / "related.csv",
    }

    db = connect()
    try:
        for table, csv_path in sources.items():
            db.sql(f"""
                CREATE TABLE {table} AS
                SELECT
                    *
                FROM read_csv_auto({sql_literal(csv_path)})
            """)
    finally:
        # Release the connection even when one of the CREATE TABLE
        # statements fails (missing CSV, table already exists, ...).
        db.close()
|
||||
|
||||
@cli.command()
|
||||
@click.option('-o', 'output_dir', type=Path, default=data_dir() / "memeorandum")
|
||||
def download(output_dir):
|
||||
|
@ -20,7 +43,8 @@ def download(output_dir):
|
|||
end = date.today()
|
||||
dates = []
|
||||
while cur <= end:
|
||||
dates.append(cur)
|
||||
if not (output_dir / f"{cur.strftime('%y-%m-%d')}.html").exists():
|
||||
dates.append(cur)
|
||||
cur = cur + day
|
||||
date_iter = tqdm(dates, postfix="test")
|
||||
for i in date_iter:
|
||||
|
@ -51,6 +75,9 @@ def parse(directory, output_dir):
|
|||
# tree = etree.parse(str(page), parser)
|
||||
tree = etree.parse(str(page), parser)
|
||||
root = tree.getroot()
|
||||
if not root:
|
||||
print(f"error opening {page}")
|
||||
continue
|
||||
items = root.xpath("//div[contains(@class, 'item')]")
|
||||
|
||||
for item in items:
|
||||
|
@ -64,8 +91,11 @@ def parse(directory, output_dir):
|
|||
else:
|
||||
author = ''
|
||||
out['author'] = author
|
||||
url = citation[0].getchildren()[0].get('href')
|
||||
publisher = citation[0].getchildren()[0].text
|
||||
try:
|
||||
url = citation[0].getchildren()[0].get('href')
|
||||
publisher = citation[0].getchildren()[0].text
|
||||
except IndexError as e:
|
||||
print(f"error with citation url: {page}")
|
||||
out['publisher'] = publisher
|
||||
out['publisher_url'] = url
|
||||
title = item.xpath('.//strong/a')[0].text
|
||||
|
|
17
src/word.py
17
src/word.py
|
@ -14,20 +14,21 @@ def train():
|
|||
table = from_db(Data.Titles)
|
||||
n_classes = 10
|
||||
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
|
||||
|
||||
# create tokens, padding to max width
|
||||
tokens = tokenizer(table['title'].apply(str).to_list(), add_special_tokens = True, truncation = True, padding = "max_length", return_attention_mask = True, return_tensors = "pt")
|
||||
pred_y = outputs[:, 0, :]
|
||||
|
||||
model = RobertaModel.from_pretrained("roberta-base")
|
||||
pred_y = model(**inputs)
|
||||
outputs = model(**tokens)
|
||||
|
||||
def get_embeddings(titles):
|
||||
# create tokens, padding to max width
|
||||
tokens = tokenizer(titles, add_special_tokens = True, truncation = True, padding = "max_length", return_attention_mask = True, return_tensors = "pt")
|
||||
outputs = model(**tokens)
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
titles = table['title'].apply(str).to_list()[:10]
|
||||
get_embeddings(titles)
|
||||
|
||||
# linear = torch.nn.Linear(model.config.hidden_size, n_classes)
|
||||
# act = torch.nn.Sigmoid()
|
||||
|
||||
# model = Model()
|
||||
pred_y.last_hidden_state[:, 0, :].shape
|
||||
classes = act(linear(pred_y.last_hidden_state[:, 0, :])).detach()
|
||||
|
||||
@cli.command()
|
||||
|
|
Loading…
Reference in New Issue