diff --git a/src/cli.py b/src/cli.py
index 4391de0..ba265fd 100644
--- a/src/cli.py
+++ b/src/cli.py
@@ -1,66 +1,9 @@
-import polars as pl
-import duckdb
-import toml
-import os
-from pathlib import Path
-import seaborn as sns
-import matplotlib.pyplot as plt
-from enum import Enum, auto
 import click
 
-DATA_DIR = Path(os.environ['DATA_MINING_DATA_DIR'])
-APP_DIR = Path(os.environ['DATA_MINING_APP_DIR'])
-
-db = duckdb.connect(str(DATA_DIR / 'project.duckdb'))
-
 @click.group()
 def cli():
     ...
 
-
-class PlotName(str, Enum):
-    TitleLength = "title_len"
-    OutletStories = "outlet_stories"
-
-@cli.command()
-@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
-@click.option('-o', '--output', required=False, type=click.Path())
-def plot(name: PlotName, output: Path):
-    output = output if output else APP_DIR / f'docs/{name}.png'
-    if name == PlotName.TitleLength:
-        fig, ax = plt.subplots(1,1)
-        data = db.sql("""
-            select
-                length(title) as len
-            from stories
-        """).df()
-        sns.histplot(x=data['len'], bins=50, ax=ax[0])
-        ax[0].set(ylabel="count", xlabel="title length")
-    elif name == PlotName.OutletStories:
-
-        data = db.sql("""
-            with cte as (
-                select
-                    count(1) as stories
-                from stories
-                group by outlet
-            )
-            select
-                row_number() over(order by stories desc) as id
-                ,log(stories) as log_count
-            from cte
-        """).df()
-
-        fig, ax = plt.subplots(1,1)
-        sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
-        from matplotlib.ticker import ScalarFormatter
-        ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)", majorformater=ScalarFormatter)
-        plt.show()
-
-    else:
-        raise NotImplementedError("option unrecognized")
-    plt.savefig(output)
-
 if __name__ == "__main__":
     import scrape
     cli.add_command(scrape.download)
@@ -76,7 +19,8 @@ if __name__ == "__main__":
     cli.add_command(bias.parse)
     cli.add_command(bias.load)
     cli.add_command(bias.normalize)
-    # import mine
-    # cli.add_command(mine.embeddings)
-    # cli.add_command(mine.cluster)
+    import mine
+    cli.add_command(mine.embeddings)
+    cli.add_command(mine.cluster)
+    cli.add_command(mine.plot)
     cli()
diff --git a/src/mine.py b/src/mine.py
index 3bc74dd..d7f1d29 100644
--- a/src/mine.py
+++ b/src/mine.py
@@ -2,6 +2,11 @@ from data import data_dir, connect
 import numpy as np
 import sklearn
 from sklearn.cluster import MiniBatchKMeans
+import click
+from pathlib import Path
+import seaborn as sns
+import matplotlib.pyplot as plt
+from enum import Enum, auto
 
 
 @click.command(name="mine:embeddings")
@@ -63,6 +68,61 @@ def cluster():
     len(stories)
     data.shape
 
+class PlotName(str, Enum):
+    TitleLength = "title_len"
+    OutletStories = "outlet_stories"
+
+@click.command(name='mine:plot')
+@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
+@click.option('-o', '--output', required=False, type=click.Path())
+def plot(name: PlotName, output: Path):
+    """Render the plot selected by --name and save it as an image.
+
+    The image goes to --output when given, otherwise to
+    APP_DIR / docs/<name>.png.
+    """
+    # NOTE(review): `db` and `APP_DIR` were module globals in cli.py;
+    # confirm mine.py defines them, otherwise this raises NameError.
+    output = output if output else APP_DIR / f'docs/{name}.png'
+    if name == PlotName.TitleLength:
+        # plt.subplots(1, 1) returns one Axes object, not an array.
+        fig, ax = plt.subplots(1,1)
+        data = db.sql("""
+            select
+                length(title) as len
+            from stories
+        """).df()
+        sns.histplot(x=data['len'], bins=50, ax=ax)
+        ax.set(ylabel="count", xlabel="title length")
+    elif name == PlotName.OutletStories:
+
+        data = db.sql("""
+            with cte as (
+                select
+                    count(1) as stories
+                from stories
+                group by outlet
+            )
+            select
+                row_number() over(order by stories desc) as id
+                ,log(stories) as log_count
+            from cte
+        """).df()
+
+        fig, ax = plt.subplots(1,1)
+        sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
+        from matplotlib.ticker import ScalarFormatter
+        ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)")
+        # The tick formatter is not an Axes.set() property; set it on the axis.
+        ax.yaxis.set_major_formatter(ScalarFormatter())
+        plt.show()
+
+    else:
+        raise NotImplementedError("option unrecognized")
+    # Save via the figure handle: once an interactive plt.show() window is
+    # closed, plt.savefig() would target a new, empty figure.
+    fig.savefig(output)
+
 def main():
     db.sql("""
     select
@@ -83,17 +143,18 @@ def main():
         ,sum(length(title)) as characters
     from cte
     """).fetchall()
-    """
-    let's calculate the size of the word embeddings stored as a list in the database
-    db.sql("""
-    with cte as (
-    select
-    distinct title
-    from stories
-    )
+    # let's calculate the size of the word embeddings stored as a list in the database
+    # db.sql("""
+    # with cte as (
+    # select
+    # distinct title
+    # from stories
+    # )
+    # ...
+    # """)
 
-    db.sql("""
-    select
-    count(distinct url)
-    from stories
-    """)
+    # db.sql("""
+    # select
+    # count(distinct url)
+    # from stories
+    # """)