synth-net / notebooks /marimo-demo.py
github-actions
Sync from GitHub (CI)
6ca4b94
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "marimo",
# "matplotlib==3.10.1",
# "numpy==2.2.4",
# "pandas",
# "plotly",
# "requests",
# ]
# ///
import marimo

# marimo version recorded when this notebook file was generated.
__generated_with = "0.14.10"
# The notebook application; every `@app.cell` function below is one cell.
# Custom styling is loaded from theme.marimo.css; html_head_file is set to an
# empty path (no custom <head> file is supplied).
app = marimo.App(css_file="theme.marimo.css", html_head_file="")
@app.cell
def _():
    # Import marimo under the alias `mo`, which the other cells take as input.
    import marimo as mo
    return (mo,)
@app.cell
def _(mo):
    # Shared applicant-count state, starting at 500 applicants. The getter
    # seeds the applicant input and bounds the hire-count input; the setter
    # is wired to the applicant input's on_change.
    _applicant_state = mo.state(500)
    get_applicant_count, set_applicant_count = _applicant_state
    return get_applicant_count, set_applicant_count
@app.cell
def _(get_applicant_count, mo, set_applicant_count):
    # Integer input for the size of the applicant pool (minimum 1). It starts
    # from the shared state's current value and pushes every change back into
    # that state through on_change.
    applicant_count = mo.ui.number(
        label="Number of applicants",
        value=get_applicant_count(),
        start=1,
        step=1,
        on_change=set_applicant_count,
        full_width=True,
    )
    return (applicant_count,)
@app.cell
def _(get_applicant_count, mo):
    # Integer input for how many applicants are hired.
    #
    # The upper bound keeps hires strictly below the applicant count, as
    # before. It is floored at 1 so that it never drops below `start`: a pool
    # of a single applicant would otherwise give stop=0 < start=1.
    _max_hires = max(1, get_applicant_count() - 1)
    hire_count = mo.ui.number(
        start=1,
        stop=_max_hires,
        step=1,
        # Default of 75 hires, clamped into [start, stop] so that a small
        # applicant pool (fewer than 76 applicants) does not yield an
        # out-of-range initial value.
        value=min(75, _max_hires),
        full_width=True,
        label="Hire count",
    )
    return (hire_count,)
@app.cell
def _(mo):
    # Slider for the base rate, an integer from 1 to 100 with default 40
    # (presumably a percentage — confirm against the API).
    base_rate = mo.ui.slider(
        start=1,
        stop=100,
        step=1,
        value=40,
        label="Base rate",
        full_width=True,
    )
    return (base_rate,)
@app.cell
def _(mo):
    # Slider for the selection procedure's validity, 0.01 to 1.00 in steps of
    # 0.01, defaulting to 0.18.
    validity = mo.ui.slider(
        start=0.01,
        stop=1.00,
        step=0.01,
        value=0.18,
        label="Validity",
        full_width=True,
    )
    return (validity,)
@app.cell
def _(applicant_count, base_rate, hire_count, mo, validity):
    # POST the current control values to the local demo API, sent as query
    # parameters, and expose the raw `requests` response to downstream cells.
    import requests

    url = "http://localhost:7860/api/v1/demo/r/demo"
    params = {
        "applicant_count": applicant_count.value,
        "hire_count": hire_count.value,
        "base_rate": base_rate.value,
        "validity": validity.value,
    }
    headers = {"accept": "application/json"}
    # Show a spinner while waiting on the API and bound the wait. Without a
    # timeout, requests may block indefinitely on an unresponsive server,
    # stalling every cell downstream. On expiry requests raises
    # requests.Timeout, so the failure surfaces in this cell instead of
    # hanging.
    with mo.status.spinner(subtitle="Loading data ..."):
        response = requests.post(url, params=params, headers=headers, timeout=30)
    return (response,)
@app.cell
def _(response):
    # Build the two-donut hiring/rejection outcome figure from the API's
    # counts. `fig` is None when the request did not return HTTP 200.
    import pandas as pd
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # Display label for each count in the API's JSON body.
    #
    # Labels are looked up by key instead of being zipped positionally with
    # json_data.keys(). The pairing therefore no longer depends on the order,
    # or number, of keys the server sends.
    # NOTE(review): this assumes the usual confusion-matrix meaning of these
    # keys (positive = hired). It reproduces the old positional pairing only
    # if the server sent the keys in this order. Confirm against the API.
    _labels_by_var = {
        "true_positives": "Rightfully hired",
        "false_positives": "Incorrectly hired",
        "true_negatives": "Rightfully rejected",
        "false_negatives": "Incorrectly rejected",
    }

    def _donut(rows, name, group):
        """Return a donut (hole=0.5) pie trace with one slice per row of `rows`."""
        return go.Pie(
            values=rows["value"].tolist(),
            labels=rows["label_with_value"].tolist(),
            hole=0.5,
            marker_colors=["#6b71ed", "#403c5d"],
            name=name,
            legendgroup=group,
            showlegend=True,
            hoverinfo="none",
        )

    # Only parse and plot a successful response. On any other status the body
    # is typically an error payload, not the counts object. Building the
    # figure from it would raise here, and that exception would also keep the
    # display cell from running, including the branch that reports non-200
    # responses.
    fig = None
    if response.status_code == 200:
        json_data = response.json()
        data = pd.DataFrame(
            {
                "var": list(_labels_by_var),
                "label": list(_labels_by_var.values()),
                # A KeyError here means the API stopped returning one of the
                # four counts.
                "value": [json_data[_var] for _var in _labels_by_var],
            }
        )
        data["label_with_value"] = (
            data["label"] + " (" + data["value"].astype(str) + ")"
        )
        hired_data = data[data["var"].isin(["true_positives", "false_positives"])]
        rejected_data = data[
            data["var"].isin(["true_negatives", "false_negatives"])
        ]

        fig = make_subplots(
            rows=1,
            cols=2,
            specs=[[{"type": "pie"}, {"type": "pie"}]],
            subplot_titles=("Hiring Outcome", "Rejection Outcome"),
            horizontal_spacing=0.1,
            vertical_spacing=0.15,
        )
        fig.add_trace(_donut(hired_data, "Hiring", "hiring"), row=1, col=1)
        fig.add_trace(
            _donut(rejected_data, "Rejection", "rejection"), row=1, col=2
        )
        fig.update_layout(
            autosize=True,
            margin=dict(l=0, r=0, t=60, b=20),
            legend=dict(
                orientation="h",
                yanchor="top",
                y=-0.1,
                xanchor="center",
                x=0.5,
                font=dict(size=12),
            ),
            title=dict(text=None, x=0.5, font=dict(size=16)),
            height=400,
            # Transparent backgrounds so the notebook theme shows through.
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(0,0,0,0)",
        )
        # Lift the subplot titles slightly above the donuts.
        fig.update_annotations(font_size=14, y=1.08)
    return (fig,)
@app.cell
def _(applicant_count, base_rate, hire_count, mo, validity):
    # NOTE(review): this cell has no visible effect.
    # - The spinner only wraps building a vstack of the controls, which is not
    #   slow work.
    # - The vstack is discarded: the cell's last statement is this `with`
    #   block rather than an expression, so nothing is rendered as output.
    # - The same four controls are displayed by the next cell.
    # This cell looks like a leftover; consider deleting it, or moving the
    # spinner around the API request.
    with mo.status.spinner(subtitle="Loading data ...") as _spinner:
        mo.vstack([applicant_count, hire_count, base_rate, validity])
    return
@app.cell
def _(applicant_count, base_rate, hire_count, mo, validity):
    # Lay out the four input controls in a vertical stack. Because the stack
    # is the cell's final expression, it is what this cell renders.
    _controls = [applicant_count, hire_count, base_rate, validity]
    mo.vstack(_controls)
    return
@app.cell
def _(fig, mo, response):
    # Show the outcome figure after a successful (HTTP 200) API call.
    # Otherwise, append an error notice to this cell's output.
    _succeeded = response.status_code == 200
    if not _succeeded:
        error_msg = mo.md(f"{mo.icon('lucide:ban')} Error in server response")
        mo.output.append(error_msg)
    else:
        # Static rendering without the plotly toolbar, resizing with its
        # container.
        _plot_config = {
            "displayModeBar": False,
            "staticPlot": True,
            "responsive": True,
        }
        fig.show(config=_plot_config)
    return
# Allow running the notebook directly as a script (`python <file>`).
if __name__ == "__main__":
    app.run()