fix mine cli.

This commit is contained in:
Matt 2023-04-22 18:27:11 -07:00
parent 6dba519443
commit c38a5455a8
2 changed files with 66 additions and 73 deletions

View File

@ -1,66 +1,9 @@
import polars as pl
import duckdb
import toml
import os
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
from enum import Enum, auto
import click import click
DATA_DIR = Path(os.environ['DATA_MINING_DATA_DIR'])
APP_DIR = Path(os.environ['DATA_MINING_APP_DIR'])
db = duckdb.connect(str(DATA_DIR / 'project.duckdb'))
@click.group() @click.group()
def cli(): def cli():
... ...
class PlotName(str, Enum):
TitleLength = "title_len"
OutletStories = "outlet_stories"
@cli.command()
@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
@click.option('-o', '--output', required=False, type=click.Path())
def plot(name: PlotName, output: Path):
output = output if output else APP_DIR / f'docs/{name}.png'
if name == PlotName.TitleLength:
fig, ax = plt.subplots(1,1)
data = db.sql("""
select
length(title) as len
from stories
""").df()
sns.histplot(x=data['len'], bins=50, ax=ax[0])
ax[0].set(ylabel="count", xlabel="title length")
elif name == PlotName.OutletStories:
data = db.sql("""
with cte as (
select
count(1) as stories
from stories
group by outlet
)
select
row_number() over(order by stories desc) as id
,log(stories) as log_count
from cte
""").df()
fig, ax = plt.subplots(1,1)
sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
from matplotlib.ticker import ScalarFormatter
ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)", majorformater=ScalarFormatter)
plt.show()
else:
raise NotImplementedError("option unrecognized")
plt.savefig(output)
if __name__ == "__main__": if __name__ == "__main__":
import scrape import scrape
cli.add_command(scrape.download) cli.add_command(scrape.download)
@ -76,7 +19,8 @@ if __name__ == "__main__":
cli.add_command(bias.parse) cli.add_command(bias.parse)
cli.add_command(bias.load) cli.add_command(bias.load)
cli.add_command(bias.normalize) cli.add_command(bias.normalize)
# import mine import mine
# cli.add_command(mine.embeddings) cli.add_command(mine.embeddings)
# cli.add_command(mine.cluster) cli.add_command(mine.cluster)
cli.add_command(mine.plot)
cli() cli()

View File

@ -2,6 +2,11 @@ from data import data_dir, connect
import numpy as np import numpy as np
import sklearn import sklearn
from sklearn.cluster import MiniBatchKMeans from sklearn.cluster import MiniBatchKMeans
import click
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
from enum import Enum, auto
@click.command(name="mine:embeddings") @click.command(name="mine:embeddings")
@ -63,6 +68,49 @@ def cluster():
len(stories) len(stories)
data.shape data.shape
class PlotName(str, Enum):
TitleLength = "title_len"
OutletStories = "outlet_stories"
@click.command(name='mine:plot')
@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
@click.option('-o', '--output', required=False, type=click.Path())
def plot(name: PlotName, output: Path):
output = output if output else APP_DIR / f'docs/{name}.png'
if name == PlotName.TitleLength:
fig, ax = plt.subplots(1,1)
data = db.sql("""
select
length(title) as len
from stories
""").df()
sns.histplot(x=data['len'], bins=50, ax=ax[0])
ax[0].set(ylabel="count", xlabel="title length")
elif name == PlotName.OutletStories:
data = db.sql("""
with cte as (
select
count(1) as stories
from stories
group by outlet
)
select
row_number() over(order by stories desc) as id
,log(stories) as log_count
from cte
""").df()
fig, ax = plt.subplots(1,1)
sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
from matplotlib.ticker import ScalarFormatter
ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)", majorformater=ScalarFormatter)
plt.show()
else:
raise NotImplementedError("option unrecognized")
plt.savefig(output)
def main(): def main():
db.sql(""" db.sql("""
select select
@ -83,17 +131,18 @@ def main():
,sum(length(title)) as characters ,sum(length(title)) as characters
from cte from cte
""").fetchall() """).fetchall()
""" # let's calculate the size of the word embeddings stored as a list in the database
let's calculate the size of the word embeddings stored as a list in the database # db.sql("""
db.sql(""" # with cte as (
with cte as ( # select
select # distinct title
distinct title # from stories
from stories # )
) # ...
# """)
db.sql(""" # db.sql("""
select # select
count(distinct url) # count(distinct url)
from stories # from stories
""") # """)