From 297aeec32de3728b419f1dfc77f11920ec89edea Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 12 Apr 2023 14:20:26 -0700
Subject: [PATCH] add loading CSV data into the database

---
 src/scrape.py | 39 +++++++++++++++++++++++++++++++++++----
 src/word.py   | 17 +++++++++--------
 2 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/src/scrape.py b/src/scrape.py
index bff666f..65d2c30 100644
--- a/src/scrape.py
+++ b/src/scrape.py
@@ -4,7 +4,7 @@ import requests
 from pathlib import Path
 import click
 from tqdm import tqdm
-from data import data_dir
+from data import data_dir, connect
 from lxml import etree
 import pandas as pd
 
@@ -12,6 +12,29 @@ import pandas as pd
 def cli():
     ...
 
+@cli.command()
+@click.option('--directory', type=Path, default=data_dir())
+@click.option('--database', type=Path, default=data_dir() / "stories.duckdb")
+def load(directory, database):
+    stories = directory / "stories.csv"
+    related = directory / "related.csv"
+    db = connect(database)
+
+    db.sql(f"""
+        CREATE TABLE stories AS
+        SELECT
+            *
+        FROM read_csv_auto('{stories}')
+    """)
+
+    db.sql(f"""
+        CREATE TABLE related_stories AS
+        SELECT
+            *
+        FROM read_csv_auto('{related}')
+    """)
+    db.close()
+
 @cli.command()
 @click.option('-o', 'output_dir', type=Path, default=data_dir() / "memeorandum")
 def download(output_dir):
@@ -20,7 +43,8 @@ def download(output_dir):
     end = date.today()
     dates = []
     while cur <= end:
-        dates.append(cur)
+        if not (output_dir / f"{cur.strftime('%y-%m-%d')}.html").exists():
+            dates.append(cur)
         cur = cur + day
     date_iter = tqdm(dates, postfix="test")
     for i in date_iter:
@@ -51,6 +75,9 @@ def parse(directory, output_dir):
         # tree = etree.parse(str(page), parser)
         tree = etree.parse(str(page), parser)
         root = tree.getroot()
+        if root is None:
+            print(f"error opening {page}")
+            continue
         items = root.xpath("//div[contains(@class, 'item')]")
 
         for item in items:
@@ -64,8 +91,12 @@
             else:
                 author = ''
             out['author'] = author
-            url = citation[0].getchildren()[0].get('href')
-            publisher = citation[0].getchildren()[0].text
+            try:
+                url = citation[0].getchildren()[0].get('href')
+                publisher = citation[0].getchildren()[0].text
+            except IndexError:
+                print(f"error with citation url: {page}")
+                url = publisher = ''
             out['publisher'] = publisher
             out['publisher_url'] = url
             title = item.xpath('.//strong/a')[0].text
diff --git a/src/word.py b/src/word.py
index 88a35ef..ec36f91 100644
--- a/src/word.py
+++ b/src/word.py
@@ -14,20 +14,21 @@ def train():
     table = from_db(Data.Titles)
     n_classes = 10
     tokenizer = AutoTokenizer.from_pretrained("roberta-base")
-
-    # create tokens, padding to max width
-    tokens = tokenizer(table['title'].apply(str).to_list(), add_special_tokens = True, truncation = True, padding = "max_length", return_attention_mask = True, return_tensors = "pt")
-    pred_y = outputs[:, 0, :]
     model = RobertaModel.from_pretrained("roberta-base")
-    pred_y = model(**inputs)
-    outputs = model(**tokens)
+
+    def get_embeddings(titles):
+        # create tokens, padding to max width
+        tokens = tokenizer(titles, add_special_tokens = True, truncation = True, padding = "max_length", return_attention_mask = True, return_tensors = "pt")
+        outputs = model(**tokens)
+        return outputs.last_hidden_state[:, 0, :]
+
+    titles = table['title'].apply(str).to_list()[:10]
+    get_embeddings(titles)
 
 
     # linear = torch.nn.Linear(model.config.hidden_size, n_classes)
     # act = torch.nn.Sigmoid()
     # model = Model()
-    pred_y.last_hidden_state[:, 0, :].shape
-    classes = act(linear(pred_y.last_hidden_state[:, 0, :])).detach()
 
 
 @cli.command()
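
---

A quick way to exercise the new load command end to end, as a sketch: this assumes src/ is on the import path, that data.connect accepts the database path (its signature is not shown in this patch), and that the target directory already holds stories.csv and related.csv; the file paths here are illustrative.

    # hypothetical smoke test for the new `load` command, using click's test runner
    from click.testing import CliRunner

    from scrape import cli

    runner = CliRunner()
    result = runner.invoke(cli, [
        "load",
        "--directory", "data",
        "--database", "data/stories.duckdb",
    ])
    # a non-zero exit code surfaces any exception raised during the CSV load
    assert result.exit_code == 0, result.output

After it runs, the tables can be inspected directly with duckdb.connect("data/stories.duckdb").sql("SELECT count(*) FROM stories") to confirm the rows landed.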