fix mine cli.
This commit is contained in:
parent
6dba519443
commit
c38a5455a8
64
src/cli.py
64
src/cli.py
|
@ -1,66 +1,9 @@
|
|||
import polars as pl
|
||||
import duckdb
|
||||
import toml
|
||||
import os
|
||||
from pathlib import Path
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from enum import Enum, auto
|
||||
import click
|
||||
|
||||
# Root directories come from the environment; a missing variable raises
# KeyError as soon as this module is imported.
DATA_DIR = Path(os.environ['DATA_MINING_DATA_DIR'])
APP_DIR = Path(os.environ['DATA_MINING_APP_DIR'])

# Single module-level DuckDB connection, opened on the project database file
# located inside DATA_DIR.
db = duckdb.connect(str(DATA_DIR.joinpath('project.duckdb')))
|
||||
|
||||
# Root click command group; sub-commands are attached to it with
# ``cli.add_command(...)`` / ``@cli.command()``.
# (Deliberately a comment, not a docstring: click would print a docstring as
# the group's --help text.)
@click.group()
def cli():
    pass
|
||||
|
||||
|
||||
class PlotName(str, Enum):
    """Names of the plots that the ``plot`` command can render.

    Members mix in ``str``, so each one compares equal to its string value.
    """

    # Histogram of story title lengths.
    TitleLength = "title_len"
    # Line plot of story counts per outlet.
    OutletStories = "outlet_stories"
|
||||
|
||||
@cli.command()
@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
@click.option('-o', '--output', required=False, type=click.Path())
def plot(name: PlotName, output: Path):
    """Render the plot selected by --name and save it as an image file."""
    # Developer notes (kept out of the docstring, which click prints as help):
    #   name   -- which plot to draw; compared against the PlotName members.
    #   output -- destination image path; defaults to
    #             APP_DIR / 'docs/<plot value>.png' when not given.
    #   Raises NotImplementedError when ``name`` matches no known plot.

    # Use the member's value explicitly: formatting a mixed-in str Enum member
    # in an f-string yields 'PlotName.X' rather than the value on newer
    # Python versions. getattr() keeps plain strings working too.
    stem = getattr(name, 'value', name)
    output = output if output else APP_DIR / f'docs/{stem}.png'

    # Only the outlet plot opened an interactive window in the original code.
    show_after_save = False

    if name == PlotName.TitleLength:
        # plt.subplots(1, 1) returns a single Axes object, not an array, so
        # it must not be indexed with ax[0].
        fig, ax = plt.subplots(1, 1)
        data = db.sql("""
            select
                length(title) as len
            from stories
        """).df()
        sns.histplot(x=data['len'], bins=50, ax=ax)
        ax.set(ylabel="count", xlabel="title length")
    elif name == PlotName.OutletStories:
        data = db.sql("""
            with cte as (
                select
                    count(1) as stories
                from stories
                group by outlet
            )
            select
                row_number() over(order by stories desc) as id
                ,log(stories) as log_count
            from cte
        """).df()

        fig, ax = plt.subplots(1, 1)
        sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
        from matplotlib.ticker import ScalarFormatter
        # NOTE(review): the values are already log-transformed in SQL and the
        # axis is additionally log-scaled; confirm the double log is intended.
        ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)")
        # 'majorformater' is not an Axes property (ax.set would raise); the
        # formatter is installed on the y axis as an instance instead.
        ax.yaxis.set_major_formatter(ScalarFormatter())
        show_after_save = True
    else:
        raise NotImplementedError("option unrecognized")

    # Save through the figure handle *before* showing: on interactive
    # backends plt.show() blocks and the figure may be gone afterwards, which
    # would leave an empty image on disk.
    fig.savefig(output)
    if show_after_save:
        plt.show()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import scrape
|
||||
cli.add_command(scrape.download)
|
||||
|
@ -76,7 +19,8 @@ if __name__ == "__main__":
|
|||
cli.add_command(bias.parse)
|
||||
cli.add_command(bias.load)
|
||||
cli.add_command(bias.normalize)
|
||||
# import mine
|
||||
# cli.add_command(mine.embeddings)
|
||||
# cli.add_command(mine.cluster)
|
||||
import mine
|
||||
cli.add_command(mine.embeddings)
|
||||
cli.add_command(mine.cluster)
|
||||
cli.add_command(mine.plot)
|
||||
cli()
|
||||
|
|
75
src/mine.py
75
src/mine.py
|
@ -2,6 +2,11 @@ from data import data_dir, connect
|
|||
import numpy as np
|
||||
import sklearn
|
||||
from sklearn.cluster import MiniBatchKMeans
|
||||
import click
|
||||
from pathlib import Path
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
@click.command(name="mine:embeddings")
|
||||
|
@ -63,6 +68,49 @@ def cluster():
|
|||
len(stories)
|
||||
data.shape
|
||||
|
||||
class PlotName(str, Enum):
    """Names of the plots that the ``mine:plot`` command can render.

    Members mix in ``str``, so each one compares equal to its string value.
    """

    # Histogram of story title lengths.
    TitleLength = "title_len"
    # Line plot of story counts per outlet.
    OutletStories = "outlet_stories"
|
||||
|
||||
@click.command(name='mine:plot')
@click.option('-n', '--name', required=True, type=click.Choice(PlotName))
@click.option('-o', '--output', required=False, type=click.Path())
def plot(name: PlotName, output: Path):
    """Render the plot selected by --name and save it as an image file."""
    # Developer notes (kept out of the docstring, which click prints as help):
    #   name   -- which plot to draw; compared against the PlotName members.
    #   output -- destination image path; defaults to
    #             APP_DIR / 'docs/<plot value>.png' when not given.
    #   Raises NotImplementedError when ``name`` matches no known plot.
    # NOTE(review): APP_DIR and db are not defined in the visible part of
    # this module -- confirm both are in scope here.

    # Use the member's value explicitly: formatting a mixed-in str Enum member
    # in an f-string yields 'PlotName.X' rather than the value on newer
    # Python versions. getattr() keeps plain strings working too.
    stem = getattr(name, 'value', name)
    output = output if output else APP_DIR / f'docs/{stem}.png'

    # Only the outlet plot opened an interactive window in the original code.
    show_after_save = False

    if name == PlotName.TitleLength:
        # plt.subplots(1, 1) returns a single Axes object, not an array, so
        # it must not be indexed with ax[0].
        fig, ax = plt.subplots(1, 1)
        data = db.sql("""
            select
                length(title) as len
            from stories
        """).df()
        sns.histplot(x=data['len'], bins=50, ax=ax)
        ax.set(ylabel="count", xlabel="title length")
    elif name == PlotName.OutletStories:
        data = db.sql("""
            with cte as (
                select
                    count(1) as stories
                from stories
                group by outlet
            )
            select
                row_number() over(order by stories desc) as id
                ,log(stories) as log_count
            from cte
        """).df()

        fig, ax = plt.subplots(1, 1)
        sns.lineplot(x=data['id'], y=data['log_count'], ax=ax)
        from matplotlib.ticker import ScalarFormatter
        # NOTE(review): the values are already log-transformed in SQL and the
        # axis is additionally log-scaled; confirm the double log is intended.
        ax.set(yscale='log', xlabel="outlet", ylabel="log(count of stories)")
        # 'majorformater' is not an Axes property (ax.set would raise); the
        # formatter is installed on the y axis as an instance instead.
        ax.yaxis.set_major_formatter(ScalarFormatter())
        show_after_save = True
    else:
        raise NotImplementedError("option unrecognized")

    # Save through the figure handle *before* showing: on interactive
    # backends plt.show() blocks and the figure may be gone afterwards, which
    # would leave an empty image on disk.
    fig.savefig(output)
    if show_after_save:
        plt.show()
|
||||
|
||||
def main():
|
||||
db.sql("""
|
||||
select
|
||||
|
@ -83,17 +131,18 @@ def main():
|
|||
,sum(length(title)) as characters
|
||||
from cte
|
||||
""").fetchall()
|
||||
"""
|
||||
let's calculate the size of the word embeddings stored as a list in the database
|
||||
db.sql("""
|
||||
with cte as (
|
||||
select
|
||||
distinct title
|
||||
from stories
|
||||
)
|
||||
# let's calculate the size of the word embeddings stored as a list in the database
|
||||
# db.sql("""
|
||||
# with cte as (
|
||||
# select
|
||||
# distinct title
|
||||
# from stories
|
||||
# )
|
||||
# ...
|
||||
# """)
|
||||
|
||||
db.sql("""
|
||||
select
|
||||
count(distinct url)
|
||||
from stories
|
||||
""")
|
||||
# db.sql("""
|
||||
# select
|
||||
# count(distinct url)
|
||||
# from stories
|
||||
# """)
|
||||
|
|
Loading…
Reference in New Issue