import itertools
from typing import Optional
import numpy as np
import plotly.graph_objects as go
def alpha_profiles(n: int):
    """Enumerate convex-combination weight vectors sampled on a regular grid.

    Every row of the result holds ``n`` strictly positive weights that sum
    to 1 (up to floating-point rounding). The first ``n - 1`` weights are
    taken from an evenly spaced grid on ``[0.01, 0.99]``; the last weight is
    whatever remains to reach 1.

    Args:
        n: Number of weights per profile. Must be at least 1.

    Returns:
        A 2-D ``numpy`` array of shape ``(num_profiles, n)``.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    if n == 2:
        # Only one free weight: a fine 100-point sweep is cheap.
        first = np.linspace(0.01, 0.99, 100)
        return np.column_stack((first, 1.0 - first))

    # The grid is coarsened as n grows (15, 10, 5 points for n = 3, 4, 5)
    # because the number of candidates is resolution ** (n - 1). Without a
    # lower bound the resolution reaches 0 at n = 6 and turns negative after
    # that, which made the function crash; keep at least 3 points per axis.
    min_resolution = 3
    resolution = max(min_resolution, 15 - max(0, (n - 3) * 5))
    grid = np.linspace(0.01, 0.99, resolution)

    # All candidate values for the first n - 1 weights.
    free_weights = np.array(list(itertools.product(grid, repeat=n - 1)))

    # Keep only candidates that leave a strictly positive last weight.
    partial_sums = free_weights.sum(axis=1)
    feasible = partial_sums < 1.0
    last_weight = 1.0 - partial_sums[feasible]
    return np.concatenate(
        [free_weights[feasible], last_weight.reshape(-1, 1)], axis=1
    )
def fig_jensen_inequality(f, x_range: list, x: np.array, y_range: Optional[list] = None):
    """Build an interactive Plotly figure illustrating Jensen's inequality.

    The figure contains, in trace order: the curve of ``f`` over ``x_range``,
    the filled polygon through the points ``(x_i, f(x_i))``, a dotted vertical
    segment, marker "A" (the weighted average of the points), marker "B" (the
    value of ``f`` at A's abscissa) and the labelled points themselves. A
    slider steps through the weight profiles produced by ``alpha_profiles``
    and moves the segment, A and B accordingly while updating the title.

    Args:
        f: Callable evaluated on ``x``, on a 100-point linspace over
            ``x_range`` and on single abscissa values.
        x_range: Two-element sequence with the lower and upper x-axis bounds.
        x: Abscissas of the points spanning the polygon.
        y_range: Passed as the y-axis ``range``; defaults to ``None``.

    Returns:
        The assembled ``go.Figure``.
    """
    curve_x = np.linspace(x_range[0], x_range[1], 100)
    points = np.column_stack([x, f(x)])
    n = len(points)

    # Shared hover text format for every trace that shows hover labels.
    hover = "(%{x:.2f}, %{y:.2f})<extra></extra>"

    hull_points, titles, steps = [], [], []
    for index, alphas in enumerate(alpha_profiles(n)):
        # Convex combination of the points under the current weights.
        combo = np.average(points, weights=alphas, axis=0)
        cx, cy = combo[0], combo[1]
        title = "$" + ",".join(
            f"\\lambda_{k}={a:.2f}" for k, a in enumerate(alphas, start=1)
        ) + "$"
        hull_points.append(combo)
        titles.append(title)
        steps.append({
            "name": index,
            "label": index,
            "method": "update",
            "args": [
                # New coordinates for the segment, marker A and marker B.
                {
                    "x": [[cx, cx], [cx], [cx]],
                    "y": [[f(cx), cy], [cy], [f(cx)]],
                },
                {"title": title},
                # Indices of the traces the data update applies to.
                [2, 3, 4],
            ],
        })

    # Start the slider in the middle of the available profiles.
    active_index = len(steps) // 2
    sliders = [{"active": active_index, "steps": steps}]

    curve_trace = go.Scatter(
        name="f",
        x=curve_x,
        y=f(curve_x),
        hoverinfo="none",
    )

    # Repeat the first point at the end so the outline is closed.
    hull_trace = go.Scatter(
        name="Convex Hull",
        x=np.append(points[:, 0], points[0, 0]),
        y=np.append(points[:, 1], points[0, 1]),
        mode="lines",
        fill="toself",
        fillcolor="rgba(239, 85, 59, 0.2)",
        line={"width": 3},
        hoverinfo="none",
        showlegend=False,
    )

    active = hull_points[active_index]

    # Dotted vertical link between f at A's abscissa and A itself.
    gap_trace = go.Scatter(
        x=[active[0], active[0]],
        y=[f(active[0]), active[1]],
        mode="lines",
        textposition="bottom center",
        hovertemplate=hover,
        line={"color": "black", "dash": "dot", "width": 1},
        showlegend=False,
    )

    average_trace = go.Scatter(
        name="A",
        x=[active[0]],
        y=[active[1]],
        mode="markers+text",
        text=["$(\\sum \\lambda_i x_i, \\sum \\lambda_i f(x_i))$"],
        textposition="top center",
        hovertemplate=hover,
        marker={"size": 20, "color": "black"},
    )

    on_curve_trace = go.Scatter(
        name="B",
        x=[active[0]],
        y=[f(active[0])],
        mode="markers",
        text=["B"],
        textposition="bottom center",
        hovertemplate=hover,
        marker={"size": 20, "color": "#00CC96"},
    )

    vertex_trace = go.Scatter(
        name="$(x_i, f(x_i))$",
        x=points[:, 0],
        y=points[:, 1],
        mode="markers+text",
        marker={"size": 20, "color": "#ffa15a"},
        line={"color": "rgba(239, 85, 59, 0.2)"},
        text=[f"$(x_{k},f(x_{k}))$" for k in range(1, n + 1)],
        textposition="top center",
        hovertemplate=hover,
        showlegend=False,
    )

    layout = go.Layout(
        title=titles[active_index],
        xaxis={"fixedrange": True, "range": x_range},
        yaxis={
            "fixedrange": True,
            "scaleanchor": "x",
            "scaleratio": 1,
            "range": y_range,
        },
        sliders=sliders,
        legend={"yanchor": "top", "xanchor": "right", "x": 1, "y": 1},
        margin={"l": 5, "r": 5, "t": 60, "b": 20},
    )

    return go.Figure(
        data=[
            curve_trace,
            hull_trace,
            gap_trace,
            average_trace,
            on_curve_trace,
            vertex_trace,
        ],
        layout=layout,
    )
def sample_parabola(x):
    """Evaluate the example parabola ``0.15 * (x - 15)**2 + 15`` at ``x``."""
    offset = x - 15
    return 0.15 * offset ** 2 + 15