import marimo

__generated_with = "0.9.14"
app = marimo.App(width="medium")


@app.cell(hide_code=True)
def __():
    import marimo as mo
    import duckdb
    import pandas
    import numpy
    import altair as alt
    import plotly.express as px

    mo.md("# 🤗 Hub Model Tree")
    return alt, duckdb, mo, numpy, pandas, px


@app.cell(hide_code=True)
def __(mo):
    mo.md("""This is powered by the [Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) dataset which you can query via the [SQL Console](https://huggingface.co/datasets/cfahlgren1/hub-stats?sql_console=true). The model tree metric is where a model tags a parent model as a `base_model`. The `hub-stats` dataset gets updated daily. Try it out by putting an organization or model author in search box and hit enter.""")
    return


@app.cell
def __(duckdb):
    # Expose the remote Hub Stats parquet file as a `models` view on DuckDB's
    # default connection. OR REPLACE lets this cell be re-run without failing
    # because the view already exists.
    duckdb.sql(
        "CREATE OR REPLACE VIEW models AS "
        "SELECT * FROM 'hf://datasets/cfahlgren1/hub-stats/models.parquet'"
    )
    # The view's name as a real Python value. Returning it (instead of an
    # undefined `models`) avoids a NameError, and because the query-builder
    # cell reads it, marimo runs every query after the view exists.
    models = "models"
    return (models,)


@app.cell(hide_code=True)
def __(mo, models):
    author_input = mo.ui.text(placeholder="Search...", label="Author")

    # Shared CTEs. After this f-string is evaluated, the remaining "{}"
    # placeholder receives the author name, which callers must SQL-escape
    # (single quotes doubled) before calling `ctes.format`.
    #   author_models: ids of models whose `author` column matches
    #   model_tags:    one row per (model id, tag) pair
    ctes = f"""
    WITH author_models AS (
        SELECT id FROM {models} WHERE author = '{{}}'
    ),
    model_tags AS (
        SELECT id, UNNEST(tags) AS tag FROM {models}
    )
    """

    def get_model_children_counts(author: str) -> str:
        """Return SQL counting direct children for each of `author`'s models.

        A child is any model carrying a `base_model:<parent id>` tag. The
        query yields `parent_model_id` and `num_direct_children`, ordered by
        the count descending; parents with no children are omitted.
        """
        # Double single quotes so an author containing "'" cannot terminate
        # the SQL string literal early (broken query / injection).
        escaped = author.replace("'", "''")
        return f"""
        {ctes.format(escaped)}
        SELECT
            am.id AS parent_model_id,
            COUNT(DISTINCT m.id) AS num_direct_children
        FROM author_models am
        INNER JOIN model_tags m ON m.tag = 'base_model:' || am.id
        GROUP BY am.id
        ORDER BY num_direct_children DESC;
        """

    def get_total_children_count(author: str) -> str:
        """Return SQL counting distinct models that tag any of `author`'s
        models as a base model.

        The query yields a single row with a single column. A child that
        tags several of the author's models is counted once.
        """
        escaped = author.replace("'", "''")
        return f"""
        {ctes.format(escaped)}
        SELECT COUNT(DISTINCT m.id) AS num_direct_children
        FROM author_models am
        LEFT JOIN model_tags m ON m.tag = 'base_model:' || am.id
        """

    return (
        author_input,
        ctes,
        get_model_children_counts,
        get_total_children_count,
    )


@app.cell
def __(mo):
    mo.md("## Search by Author")
    return


@app.cell(hide_code=True)
def __(author_input, mo):
    mo.vstack([author_input, mo.md("_ex: meta-llama, google, mistralai, Qwen_")])
    return


@app.cell(hide_code=True)
def __(author_input, duckdb, get_total_children_count, mo):
    # Ignore surrounding whitespace so " google " still matches "google".
    _author = author_input.value.strip()
    # Every query scans the remote parquet; don't run one for an empty search.
    mo.stop(not _author, mo.md("_Enter an author above to see its child models._"))
    result = duckdb.sql(get_total_children_count(_author)).fetchall()
    mo.vstack(
        [
            mo.md("### Direct Child Models"),
            mo.md(
                f"_The number of models that have tagged a {_author} model as a `base_model`_"
            ),
            # The aggregate query always returns exactly one row.
            mo.stat(result[0][0]),
        ]
    )
    return (result,)


@app.cell(hide_code=True)
def __(author_input, duckdb, get_model_children_counts, mo):
    _author = author_input.value.strip()
    mo.stop(not _author)
    df = duckdb.sql(get_model_children_counts(_author)).fetchdf()
    df
    return (df,)


@app.cell(hide_code=True)
def __(df, mo, px):
    # Nothing to chart when none of the author's models has children.
    mo.stop(df.empty)
    # Log scale: child counts span several orders of magnitude across models.
    _plot = px.bar(df, x="parent_model_id", y="num_direct_children", log_y=True)
    mo.ui.plotly(_plot)
    return


if __name__ == "__main__":
    app.run()