Geometric Intuition for Jensen’s Inequality

convex function
jensen
Author

Madiyar Aitbayev

Published

January 4, 2025

Introduction

Jensen’s inequality is fundamental in many fields, including machine learning and statistics. For example, it is useful in the diffusion models paper for understanding the variational lower bound. In this post, I will give a simple geometric intuition for Jensen’s inequality.

Feel free to leave feedback on my Telegram channel.

Setup

The post contains collapsed code sections that are used to produce the visualizations. They’re optional, hence collapsed.

code for fig_jensen_inequality
import itertools

import numpy as np
import plotly.graph_objects as go


def alpha_profiles(n: int):
    # Grids of weight vectors (lambda_1, ..., lambda_n) with positive entries
    # that sum to 1; each row drives one slider step in the figure.
    if n == 2:
        space = np.linspace(0.01, 0.99, 100)
        return np.column_stack((space, 1.0 - space))
    # For n >= 3, grid the first n - 1 weights and keep the combinations that
    # sum to less than 1; the grid is coarser for larger n to keep the number
    # of slider steps manageable.
    space = np.linspace(0.01, 0.99, 15 - max(0, (n - 3) * 5))
    space_prod = itertools.product(*[space for _ in range(n - 1)])
    profiles = np.array(list(space_prod))
    profiles = profiles[np.sum(profiles, axis=1) < 1.0]
    # The last weight is determined by the constraint sum(lambda_i) == 1.
    return np.concatenate([profiles, 1 - np.sum(profiles, axis=1).reshape(-1, 1)], axis=1)


def fig_jensen_inequality(f, x_range: np.ndarray, x: np.ndarray, showlegend: bool = True):
    # Points (x_i, f(x_i)) on the graph of f.
    points = np.column_stack([x, f(x)])
    n = len(points)
    steps = []
    hull_points = []
    titles = []
    for index, alphas in enumerate(alpha_profiles(n)):
        # The convex combination A = (sum lambda_i x_i, sum lambda_i f(x_i)).
        hp = np.average(points, weights=alphas, axis=0)
        hull_points.append(hp)
        title = ",".join(["\\lambda_" + f"{i + 1}={a:.2f}" for i, a in enumerate(alphas)])
        title = f"${title}$"
        titles.append(title)
        # Each slider step moves the black marker (trace 2) and the dotted
        # segment between A and B = (A_x, f(A_x)) (trace 3).
        step = dict(name=index, label=index, method="update",
                    args=[{
                        "x": [[hp[0]], [hp[0], hp[0]]],
                        "y": [[hp[1]], [f(hp[0]), hp[1]]],
                    }, {"title": title}, [2, 3]])
        steps.append(step)
    active_index = len(steps) // 2
    sliders = [dict(active=active_index, steps=steps)]
    return go.Figure(data=[
        go.Scatter(
            name="f", x=x_range, y=f(x_range), hoverinfo="none"
        ),
        go.Scatter(
            name="Convex Hull", x=np.append(points[:, 0], points[0][0]),
            y=np.append(points[:, 1], points[0][1]),
            fillcolor="rgba(239, 85, 59, 0.2)", fill="toself", mode="lines",
            line=dict(width=3), hoverinfo="none",
            showlegend=points.shape[0] > 2
        ),
        go.Scatter(
            name="$(\\sum \\lambda_i x_i, \\sum \\lambda_i f(x_i))$",
            x=[hull_points[active_index][0]],
            y=[hull_points[active_index][1]],
            mode="markers+text",
            text=["$(\\sum \\lambda_i x_i, \\sum \\lambda_i f(x_i))$"],
            textposition="top center",
            hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>",
            marker={"size": 20, "color": "black"},
            legendrank=1001,
        ),
        go.Scatter(
            x=[hull_points[active_index][0], hull_points[active_index][0]],
            y=[f(hull_points[active_index][0]), hull_points[active_index][1]],
            mode="lines",
            textposition="bottom center",
            hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>",
            line={"color": "black", "dash": "dot", "width": 1},
            showlegend=False
        ),
        go.Scatter(
            name="$(x_i, f(x_i))$",
            x=points[:, 0], y=points[:, 1],
            mode="markers+text",
            marker={"size": 20},
            text=[f"$(x_{i},f(x_{i}))$" for i in range(1, n + 1)],
            textposition="top center",
            hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>"
        ),

    ], layout=go.Layout(
        title=titles[active_index],
        xaxis=dict(fixedrange=True),
        yaxis=dict(fixedrange=True, scaleanchor="x", scaleratio=1),
        sliders=sliders,
        legend=dict(
            yanchor="top",
            xanchor="right",
            x=1,
            y=1
        ),
        margin=dict(l=5, r=5, t=50, b=50),
        showlegend=showlegend
    ))


def sample_parabola(x):
    # The convex function f(x) = 0.15 (x - 15)^2 + 15 used throughout the post.
    return 0.15 * (x - 15) ** 2 + 15

Convex Function

A function is convex when the line segment joining any two points on its graph lies above or on the graph. In the simplest terms, a convex function is shaped like \(\cup\) and a concave function is shaped like \(\cap\). If \(f\) is convex, then \(-f\) is concave.

A visualization from Wikipedia:

display image from Wikipedia
from IPython.display import Image
Image(url='https://upload.wikimedia.org/wikipedia/commons/c/c7/ConvexFunction.svg', width=400)

Definition

A function \(f\) is called convex if the following holds for all \(x_1, x_2\) in its domain and all \(\lambda \in [0, 1]\):

\[ f(\lambda x_1 + (1-\lambda) x_2) \le \lambda f(x_1) + (1-\lambda) f(x_2) \]

and concave when:

\[ f(\lambda x_1 + (1-\lambda) x_2) \ge \lambda f(x_1) + (1-\lambda) f(x_2) \]

We will give geometric intuition for this definition in the next section.
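
As a quick sanity check, here is a minimal sketch (assuming only numpy) that verifies the convexity definition numerically for the parabola used throughout this post:

numeric check of the convexity definition
import numpy as np


def sample_parabola(x):
    # Same parabola as in the setup code: f(x) = 0.15 (x - 15)^2 + 15.
    return 0.15 * (x - 15) ** 2 + 15


x1, x2 = 2.0, 22.0  # arbitrary points in the domain
for lam in np.linspace(0.0, 1.0, 11):
    lhs = sample_parabola(lam * x1 + (1 - lam) * x2)                   # f of the mixture
    rhs = lam * sample_parabola(x1) + (1 - lam) * sample_parabola(x2)  # chord height
    assert lhs <= rhs + 1e-9  # tolerance for floating-point roundoff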

Geometric Intuition

plots jensen’s inequality for n=2
fig = fig_jensen_inequality(
    f=sample_parabola,
    x_range=np.linspace(0, 30, 100),
    x=np.array([2, 22]),
    showlegend=False
)
fig.show(renderer="iframe")

An interactive visualization of the convex function \(f(x)=0.15(x - 15)^2 + 15\). We will use the same parabola throughout this post unless stated otherwise. You can use the slider to try different values of \((\lambda_1, \lambda_2)\), where \(\lambda_2=1-\lambda_1\).

We have a line segment joining \((x_1, f(x_1))\) and \((x_2, f(x_2))\). Any point on this segment can be written as \((\lambda_1 x_1 + \lambda_2 x_2, \lambda_1 f(x_1) + \lambda_2 f(x_2))\), where \(\lambda_1, \lambda_2 \ge 0\) and \(\lambda_1 + \lambda_2 = 1\). For example:

  • When \(\lambda_1=1\), we get the first point
  • When \(\lambda_1=0\), we get the second point
  • When \(\lambda_1=0.5\), we get the midpoint of the line segment
  • and so on… Try the slider above!

This point is shown as the black marker above. Let’s call it \(A\).

The point where the function graph meets the dotted line segment is \((\lambda_1 x_1 + \lambda_2 x_2, f(\lambda_1 x_1 + \lambda_2 x_2))\). Let’s call it \(B\).

Then the definition above simply asserts that \(B_y \le A_y\), where \(A_x = B_x\) by construction. Note that we are only showing a single line segment, but the statement must hold for every such line segment.
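
To make the two points concrete, here is a minimal sketch that computes \(A\) and \(B\) for the segment above (the weight \(\lambda_1=0.3\) is an arbitrary choice):

computing the points A and B for n=2
import numpy as np


def sample_parabola(x):
    return 0.15 * (x - 15) ** 2 + 15


x1, x2 = 2.0, 22.0  # same points as in the figure above
lam1 = 0.3          # arbitrary weight; lam2 follows from the constraint
lam2 = 1.0 - lam1
A = np.array([lam1 * x1 + lam2 * x2,
              lam1 * sample_parabola(x1) + lam2 * sample_parabola(x2)])  # on the chord
B = np.array([A[0], sample_parabola(A[0])])                              # on the graph
print(A, B)  # A_x == B_x, and B_y <= A_y by convexity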

Jensen’s Inequality

Jensen’s inequality generalizes the convex function definition above to more than two points.

Definition

Assume we have a convex function \(f\), points \(x_1, x_2, \cdots, x_n\) in \(f\)’s domain, and positive weights \(\lambda_1, \lambda_2, \cdots, \lambda_n\) with \(\sum_{i=1}^n \lambda_i = 1\). Then Jensen’s inequality states:

\[ f(\sum_{i=1}^n \lambda_i x_i) \le \sum_{i=1}^n \lambda_i f(x_i) \]

The inequality is flipped for a concave function \(g\):

\[ g(\sum_{i=1}^n \lambda_i x_i) \ge \sum_{i=1}^n \lambda_i g(x_i) \]

Note that when \(n=2\) we recover the definition of a convex function.
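
As a quick numeric sanity check of the general statement, here is a minimal sketch (assuming only numpy; the Dirichlet draw is just a convenient way to get positive weights that sum to 1):

numeric check of jensen’s inequality
import numpy as np


def sample_parabola(x):
    return 0.15 * (x - 15) ** 2 + 15


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 30.0, size=5)     # arbitrary points in the domain
lam = rng.dirichlet(np.ones(5))        # positive weights summing to 1
lhs = sample_parabola(np.dot(lam, x))  # f(sum lambda_i x_i)
rhs = np.dot(lam, sample_parabola(x))  # sum lambda_i f(x_i)
assert lhs <= rhs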

Geometric Intuition

Numerous proofs are already available in other posts, and I encourage you to seek them out.

Here I describe a geometric intuition, which resonates more with me.

Triangle

Let’s start with a triangle, i.e., \(n=3\):

plots jensen’s inequality for n=3
fig = fig_jensen_inequality(
    f=sample_parabola,
    x_range=np.linspace(0, 30, 100),
    x=np.array([2, 13, 25]),
    showlegend=False
)
fig.show(renderer="iframe")

As before, you can use the slider to try different values of \((\lambda_1, \lambda_2, \lambda_3)\) where \(\lambda_1+\lambda_2+\lambda_3=1\).

We have a triangle that connects the points: \((x_1, f(x_1)), (x_2, f(x_2)), (x_3, f(x_3))\).

In the \(n=2\) case, we used \(\lambda_1\) and \(\lambda_2\) to sample a point along the line segment. This case is similar, but now we can sample any point inside or on the boundary of the triangle with:

\[ \left(\lambda_1x_1+\lambda_2x_2+\lambda_3x_3, \lambda_1f(x_1)+\lambda_2f(x_2)+\lambda_3f(x_3)\right) \]

For example:

  • When \(\lambda_i=1\) for some \(i \in \{1, 2, 3\}\) (and the other two weights are 0), we get the vertex \((x_i, f(x_i))\)
  • When \(\lambda_1=\lambda_2=\lambda_3=\frac{1}{3}\), we get the centroid (center of mass) of the triangle

The black point in the visualization is exactly this point. Let’s call it \(A\).

Note that \((\lambda_1, \lambda_2, \lambda_3)\) are the barycentric coordinates of \(A\) with respect to the triangle. You don’t need to know this for the post; I mention it in case you’re already familiar with the concept.

The point where the parabola meets the dotted line segment is described by:

\[ (\lambda_1x_1+\lambda_2x_2+\lambda_3x_3, f(\lambda_1x_1+\lambda_2x_2+\lambda_3x_3)) \]

If we call this point \(B\), then it is not difficult to see that Jensen’s inequality for \(n=3\) is exactly the statement \(B_y \le A_y\).
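
Here is a minimal sketch that checks \(B_y \le A_y\) for many random weight triples (drawn from a Dirichlet distribution, an arbitrary choice that covers the whole triangle):

checking B_y <= A_y over the triangle
import numpy as np


def sample_parabola(x):
    return 0.15 * (x - 15) ** 2 + 15


xs = np.array([2.0, 13.0, 25.0])  # same points as in the figure above
rng = np.random.default_rng(0)
for lam in rng.dirichlet(np.ones(3), size=1000):
    A_x = np.dot(lam, xs)
    A_y = np.dot(lam, sample_parabola(xs))
    assert sample_parabola(A_x) <= A_y  # B_y <= A_y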

Four Points or More

It is easy to generalize to \(n>3\); I am adding it here for the sake of completeness:

plots jensen’s inequality for n=4
fig = fig_jensen_inequality(
    f=sample_parabola,
    x_range=np.linspace(0, 30, 100),
    x=np.array([2, 13, 22, 25]),
    showlegend=False
)
fig.show(renderer="iframe")

In the general case, \((\sum_{i=1}^n \lambda_ix_i, \sum_{i=1}^n \lambda_if(x_i))\) describes a point inside or on the boundary of the convex hull of the points \((x_1, f(x_1)), (x_2, f(x_2)), \cdots, (x_n, f(x_n))\). The convex hull always lies on or above the graph.
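
Here is a minimal sketch of that claim, using scipy’s Delaunay triangulation for the point-in-hull test (note that scipy is an extra dependency; the rest of this post only needs numpy and plotly):

convex combinations land inside the hull
import numpy as np
from scipy.spatial import Delaunay


def sample_parabola(x):
    return 0.15 * (x - 15) ** 2 + 15


xs = np.array([2.0, 13.0, 22.0, 25.0])  # same points as in the figure above
pts = np.column_stack([xs, sample_parabola(xs)])
tri = Delaunay(pts)  # triangulates the region bounded by the convex hull
rng = np.random.default_rng(0)
lam = rng.dirichlet(np.ones(len(xs)), size=1000)
combos = lam @ pts  # rows are (sum lambda_i x_i, sum lambda_i f(x_i))
assert np.all(tri.find_simplex(combos) >= 0)  # -1 would mean outside the hull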

A few closing notes:

  • The convex hull may be built from any number of points, including \(n \to \infty\)
  • As \(n \to \infty\), the convex hull closely approximates the convex function on the interval spanned by the points
  • In the limit \(n \to \infty\), the convexity definition for functions and the convexity definition for polygons coincide
  • Jensen’s inequality is especially useful in probability theory, since the weights satisfy \(\sum_{i=1}^n \lambda_i = 1\) like probabilities; the continuous form with \(n \to \infty\) reads \(f(\mathbb{E}[X]) \le \mathbb{E}[f(X)]\) (see the sketch after this list)
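
Here is a minimal sketch of the probabilistic form \(f(\mathbb{E}[X]) \le \mathbb{E}[f(X)]\), estimating both sides by sampling (the normal distribution is an arbitrary choice):

monte carlo check of f(E[X]) <= E[f(X)]
import numpy as np


def sample_parabola(x):
    return 0.15 * (x - 15) ** 2 + 15


rng = np.random.default_rng(0)
xs = rng.normal(15.0, 5.0, size=100_000)  # samples of X ~ N(15, 5^2)
print(sample_parabola(xs.mean()))  # estimates f(E[X])
print(sample_parabola(xs).mean())  # estimates E[f(X)]; the larger of the two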

Applications

The arithmetic mean-geometric mean inequality (AM-GM inequality) states that: \[ \frac{x_1+x_2+\cdots+x_n}{n} \ge \sqrt[n]{x_1x_2\cdots x_n} \]

Let’s prove it with Jensen’s inequality by rewriting the above with \(\lambda_1=\lambda_2=\cdots=\lambda_n=\frac{1}{n}\):

\[ \sum_{i=1}^n \lambda_i x_i \ge \prod_{i=1}^n x_i^{\lambda_i} \]

Since \(\log\) is monotonically increasing, applying it to both sides yields an equivalent inequality:

\[ \log(\sum_{i=1}^n \lambda_i x_i) \ge \log(\prod_{i=1}^n x_i^{\lambda_i}) = \sum_{i=1}^n \lambda_i \log(x_i) \]

The above inequality holds by Jensen’s inequality, because \(\log\) is concave; this completes the proof. Note that the same argument proves the weighted AM-GM inequality, since nothing in it relies on \(\lambda_i=\frac{1}{n}\) for all \(i=1,2,\cdots,n\).
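
Here is a minimal sketch that checks both the uniform and the weighted versions numerically (the inputs and weights are arbitrary):

numeric check of the AM-GM inequality
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.1, 10.0, size=6)  # positive numbers
lam = rng.dirichlet(np.ones(6))     # arbitrary positive weights summing to 1

# Uniform weights: arithmetic mean vs geometric mean.
assert x.mean() >= np.exp(np.log(x).mean())

# Weighted version: sum lambda_i x_i >= prod x_i^lambda_i.
assert np.dot(lam, x) >= np.exp(np.dot(lam, np.log(x)))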

The End

I hope you enjoyed this post. You can ask further questions on my Telegram channel.