fix mine cli.

This commit is contained in:
Matt 2023-04-22 18:27:11 -07:00
parent 6dba519443
commit c38a5455a8
2 changed files with 66 additions and 73 deletions

View File

@ -1,66 +1,9 @@
import polars as pl
import duckdb
import toml
import os
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
from enum import Enum, auto
import click
DATA_DIR = Path(os.environ['DATA_MINING_DATA_DIR'])
APP_DIR = Path(os.environ['DATA_MINING_APP_DIR'])
db = duckdb.connect(str(DATA_DIR / 'project.duckdb'))
@click.group()
def cli():
...
class PlotName(str, Enum):
TitleLength = "title_len"
OutletStories = "outlet_stories"
@cli.command()
@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
@click.option('-o', '--output', required=False, type=click.Path())
def plot(name: PlotName, output: Path):
output = output if output else APP_DIR / f'docs/{name}.png'
if name == PlotName.TitleLength:
fig, ax = plt.subplots(1,1)
data = db.sql("""
select
length(title) as len
from stories
""").df()
sns.histplot(x=data['len'], bins=50, ax=ax[0])
ax[0].set(ylabel="count", xlabel="title length")
elif name == PlotName.OutletStories:
data = db.sql("""
with cte as (
select
count(1) as stories
from stories
group by outlet
)
select
row_number() over(order by stories desc) as id
,log(stories) as log_count
from cte
""").df()
fig, ax = plt.subplots(1,1)
sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
from matplotlib.ticker import ScalarFormatter
ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)", majorformater=ScalarFormatter)
plt.show()
else:
raise NotImplementedError("option unrecognized")
plt.savefig(output)
if __name__ == "__main__":
import scrape
cli.add_command(scrape.download)
@ -76,7 +19,8 @@ if __name__ == "__main__":
cli.add_command(bias.parse)
cli.add_command(bias.load)
cli.add_command(bias.normalize)
# import mine
# cli.add_command(mine.embeddings)
# cli.add_command(mine.cluster)
import mine
cli.add_command(mine.embeddings)
cli.add_command(mine.cluster)
cli.add_command(mine.plot)
cli()

View File

@ -2,6 +2,11 @@ from data import data_dir, connect
import numpy as np
import sklearn
from sklearn.cluster import MiniBatchKMeans
import click
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt
from enum import Enum, auto
@click.command(name="mine:embeddings")
@ -63,6 +68,49 @@ def cluster():
len(stories)
data.shape
class PlotName(str, Enum):
TitleLength = "title_len"
OutletStories = "outlet_stories"
@click.command(name='mine:plot')
@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
@click.option('-o', '--output', required=False, type=click.Path())
def plot(name: PlotName, output: Path):
output = output if output else APP_DIR / f'docs/{name}.png'
if name == PlotName.TitleLength:
fig, ax = plt.subplots(1,1)
data = db.sql("""
select
length(title) as len
from stories
""").df()
sns.histplot(x=data['len'], bins=50, ax=ax[0])
ax[0].set(ylabel="count", xlabel="title length")
elif name == PlotName.OutletStories:
data = db.sql("""
with cte as (
select
count(1) as stories
from stories
group by outlet
)
select
row_number() over(order by stories desc) as id
,log(stories) as log_count
from cte
""").df()
fig, ax = plt.subplots(1,1)
sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
from matplotlib.ticker import ScalarFormatter
ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)", majorformater=ScalarFormatter)
plt.show()
else:
raise NotImplementedError("option unrecognized")
plt.savefig(output)
def main():
db.sql("""
select
@ -83,17 +131,18 @@ def main():
,sum(length(title)) as characters
from cte
""").fetchall()
"""
let's calculate the size of the word embeddings stored as a list in the database
db.sql("""
with cte as (
select
distinct title
from stories
)
# let's calculate the size of the word embeddings stored as a list in the database
# db.sql("""
# with cte as (
# select
# distinct title
# from stories
# )
# ...
# """)
db.sql("""
select
count(distinct url)
from stories
""")
# db.sql("""
# select
# count(distinct url)
# from stories
# """)