Jensen’s inequality is fundamental in many fields, including machine learning and statistics. For example, it is useful in the diffusion models paper for understanding the variational lower bound. In this post, I will give a simple geometric intuition for Jensen’s inequality.
A function is convex when the line segment joining any two points on its graph lies above or on the graph. In the simplest terms, a convex function is shaped like \(\cup\) and a concave function is shaped like \(\cap\). If \(f\) is convex, then \(-f\) is concave. Formally, a function \(f: X \to \mathbb{R}\) is convex if \[
f(\lambda x_1 + (1-\lambda) x_2) \le \lambda f(x_1) + (1-\lambda) f(x_2)
\]
for all \(0 \le \lambda \le 1\) and for all \(x_1, x_2 \in X\).
We will give geometric intuition for this definition in the next section.
Geometric Intuition
show_sample_jensen_inequality(x=[2, 22])
An interactive visualization of the convex function \(f(x)=0.15(x - 15)^2 + 15\). We will use the same parabola throughout this post unless stated otherwise. You can use the slider to try different values of \((\lambda_1, \lambda_2)\), where \(\lambda_1=\lambda\) (from the definition) and \(\lambda_2=1-\lambda\).
We have a line segment that joins \((x_1, f(x_1))\) and \((x_2, f(x_2))\). We can sample any point along the line segment with \((\lambda_1 x_1 + \lambda_2 x_2, \lambda_1 f(x_1) + \lambda_2 f(x_2))\) where \(0 \le \lambda_1 \le 1\) and \(\lambda_2 = 1 - \lambda_1\). For example:
When \(\lambda_1=1\), we get the first point
When \(\lambda_1=0\), we get the second point
And when \(\lambda_1=0.5\), we get the middle point of the line segment
and so on… Try the slider above and notice how \(\lambda_1\) and \(\lambda_2\) are changing!
This point is shown as the black point in the visualization above. Let’s call it A.
The light green point, where the function graph intersects the dotted line segment, is \((\lambda_1 x_1 + \lambda_2 x_2, f(\lambda_1 x_1 + \lambda_2 x_2))\). Let’s call it B.
Then, the definition above is just asserting that \(B_y \le A_y\) and we also have \(A_x = B_x\). Note that we are only showing a single line segment, but this statement should be true for all similar line segments and for all \(0 \le \lambda_1 \le 1\).
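To make the comparison between A and B concrete, here is a minimal numerical sketch for the same parabola and the same endpoints \(x_1=2\), \(x_2=22\). The helper name `points_a_and_b` is hypothetical and is not part of the visualization code used in this post.

```python
def f(x):
    """The parabola used throughout the post."""
    return 0.15 * (x - 15) ** 2 + 15


def points_a_and_b(x1, x2, lam1):
    """Return A (on the line segment) and B (on the graph) for weight lam1.

    Hypothetical helper for illustration only.
    """
    lam2 = 1 - lam1
    x = lam1 * x1 + lam2 * x2             # shared x-coordinate: A_x = B_x
    a = (x, lam1 * f(x1) + lam2 * f(x2))  # point on the segment
    b = (x, f(x))                         # point on the graph
    return a, b


x1, x2 = 2, 22
for lam1 in (0.0, 0.25, 0.5, 0.75, 1.0):
    a, b = points_a_and_b(x1, x2, lam1)
    # Convexity says B is never above A (small tolerance for rounding).
    assert b[1] <= a[1] + 1e-12
    print(f"lambda_1={lam1:.2f}  A_y={a[1]:.4f}  B_y={b[1]:.4f}")
```

At the two endpoints (\(\lambda_1=0\) and \(\lambda_1=1\)) A and B coincide, and the gap between them is largest somewhere in between.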
Jensen’s Inequality
Jensen’s inequality generalizes the convex function definition above from two points to any number of points.
Definition
Assume we have a convex function \(f\), points \(x_1, x_2, \cdots, x_n\) in \(f\)’s domain, and non-negative weights \(\lambda_1, \lambda_2, \cdots, \lambda_n\) where \(\sum_{i=1}^n \lambda_i = 1\). Then Jensen’s inequality can be stated as: \[
f\left(\sum_{i=1}^n \lambda_i x_i\right) \le \sum_{i=1}^n \lambda_i f(x_i)
\]
Here I describe a geometric intuition, which resonates more with me.
Triangle
Let’s start with a triangle, i.e., \(n=3\):
show_sample_jensen_inequality(x=[2, 12, 27])
As before, you can use the slider to try different values of \((\lambda_1, \lambda_2, \lambda_3)\) where \(\lambda_1+\lambda_2+\lambda_3=1\).
We have a triangle that connects the points: \((x_1, f(x_1)), (x_2, f(x_2)), (x_3, f(x_3))\).
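In the \(n=2\) case, we used \(\lambda_1\) and \(\lambda_2\) to sample a point along the line segment. This case is similar, but now we can sample any point inside or on the boundary of the triangle with: \[
(\lambda_1 x_1 + \lambda_2 x_2 + \lambda_3 x_3,\ \lambda_1 f(x_1) + \lambda_2 f(x_2) + \lambda_3 f(x_3))
\]
where \(\lambda_1, \lambda_2, \lambda_3 \ge 0\) and \(\lambda_1+\lambda_2+\lambda_3=1\). For example: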
When \(\lambda_i=1\) where \(i \in \{1, 2, 3\}\), we get the point \((x_i, f(x_i))\)
When \(\lambda_1=\lambda_2=\lambda_3=\frac{1}{3}\), we get the center of mass of the triangle
The black point (named A) in the visualization represents this point.
Note that \((\lambda_1, \lambda_2, \lambda_3)\) are the barycentric coordinates of A with respect to the triangle. You don’t need to know them for this post; I am just mentioning it in case you’re already familiar with them.
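The light green point where the parabola meets the dotted line segment is represented by: \[
(\lambda_1 x_1 + \lambda_2 x_2 + \lambda_3 x_3,\ f(\lambda_1 x_1 + \lambda_2 x_2 + \lambda_3 x_3))
\]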
If we name this point as B, then it is not difficult to see that Jensen’s inequality is the same as \(B_y \le A_y\).
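The triangle case can be checked numerically in the same way. The sketch below uses the three points from the visualization, \(x=[2, 12, 27]\). The function name `jensen_points` is a hypothetical helper that works for any number of points.

```python
def f(x):
    """The parabola used throughout the post."""
    return 0.15 * (x - 15) ** 2 + 15


def jensen_points(xs, lams):
    """Return A (inside the polygon) and B (on the graph) for the given weights.

    Hypothetical helper: the weights must be non-negative and sum to 1.
    """
    assert all(lam >= 0 for lam in lams)
    assert abs(sum(lams) - 1) < 1e-9
    x = sum(lam * xi for lam, xi in zip(lams, xs))
    a_y = sum(lam * f(xi) for lam, xi in zip(lams, xs))
    return (x, a_y), (x, f(x))


xs = [2, 12, 27]
for lams in ([1, 0, 0], [1 / 3, 1 / 3, 1 / 3], [0.2, 0.5, 0.3]):
    a, b = jensen_points(xs, lams)
    # Jensen's inequality: B_y <= A_y (small tolerance for rounding).
    assert b[1] <= a[1] + 1e-12
    print(f"weights={lams}  A=({a[0]:.3f}, {a[1]:.3f})  B_y={b[1]:.3f}")
```

With weights \((1, 0, 0)\) the two points coincide at the first vertex, and with equal weights A is the center of mass of the triangle.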
Four Points or More
It is easy to generalize to \(n>3\). I am adding it here for the sake of completeness:
show_sample_jensen_inequality(x=[2, 13, 22, 25])
In the general case, \((\sum_{i=1}^n \lambda_ix_i, \sum_{i=1}^n \lambda_if(x_i))\) describes a point inside or on the boundary of the convex hull enclosing the points: \((x_1, f(x_1)), (x_2, f(x_2)), \cdots, (x_n, f(x_n))\). The convex hull is always above or on the convex function graph, which is why Jensen’s inequality holds true.
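As a quick sanity check of the general case, the sketch below draws many random weight vectors for the four points \(x=[2, 13, 22, 25]\) and confirms that the gap between the two sides of Jensen’s inequality is never negative. This is only a numerical spot check, not a proof, and the helper names `random_weights` and `jensen_gap` are hypothetical.

```python
import random


def f(x):
    """The parabola used throughout the post."""
    return 0.15 * (x - 15) ** 2 + 15


def random_weights(n, rng):
    """Draw n non-negative weights that sum to 1 by normalizing uniform samples."""
    raw = [rng.random() for _ in range(n)]
    total = sum(raw)
    return [r / total for r in raw]


def jensen_gap(xs, lams):
    """Return sum(lam_i * f(x_i)) - f(sum(lam_i * x_i)), i.e. A_y - B_y."""
    a_y = sum(lam * f(x) for lam, x in zip(lams, xs))
    b_y = f(sum(lam * x for lam, x in zip(lams, xs)))
    return a_y - b_y


rng = random.Random(0)  # fixed seed so the run is reproducible
xs = [2, 13, 22, 25]
gaps = [jensen_gap(xs, random_weights(len(xs), rng)) for _ in range(1000)]

# Every sampled point of the convex hull lies on or above the graph.
assert min(gaps) >= -1e-9
print(f"smallest gap over {len(gaps)} samples: {min(gaps):.4f}")
```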
A few closing notes:
The convex hull may have any number of points, including n → ∞
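As n approaches infinity and the points fill some interval densely, the lower boundary of the convex hull closely approximates the convex function on that interval
In that limit, the convexity definitions for functions and for polygons describe the same shape: the region on or above the graph of a convex function is itself a convex set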
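Jensen’s inequality is useful in probability theory: since the weights are non-negative and \(\sum_{i=1}^n \lambda_i = 1\), they can be read as probabilities, which gives \(f(\mathbb{E}[X]) \le \mathbb{E}[f(X)]\) for convex \(f\), including the continuous form with n → ∞.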
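To illustrate the probabilistic reading, here is a small sketch in which the weights are the probabilities of a fair six-sided die and the convex function is \(f(x)=x^2\) (my choice for this example, not the parabola used above). In that case Jensen’s inequality says \((\mathbb{E}[X])^2 \le \mathbb{E}[X^2]\), which is the statement that the variance is non-negative. The helper `expectation` is hypothetical.

```python
def expectation(values, probs, g=lambda v: v):
    """Expected value of g(X) for a discrete random variable X."""
    return sum(p * g(v) for v, p in zip(values, probs))


# A fair six-sided die: the weights lambda_i are the probabilities 1/6.
values = [1, 2, 3, 4, 5, 6]
probs = [1 / 6] * 6

mean = expectation(values, probs)                          # E[X]
mean_sq = expectation(values, probs, g=lambda v: v ** 2)   # E[X^2]
variance = mean_sq - mean ** 2                             # the Jensen gap for f(x) = x^2

# Jensen's inequality with f(x) = x^2: f(E[X]) <= E[f(X)].
assert mean ** 2 <= mean_sq
print(f"E[X]={mean:.4f}  E[X^2]={mean_sq:.4f}  variance={variance:.4f}")
```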
Applications
AM–GM inequality
The arithmetic mean-geometric mean inequality (AM-GM inequality) states that for non-negative real numbers \(x_1, x_2, \cdots, x_n\): \[
\frac{x_1+x_2+\cdots+x_n}{n} \ge \sqrt[n]{x_1x_2\cdots x_n}
\]
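Let’s prove it with Jensen’s inequality. Assume all \(x_i > 0\) (if some \(x_i = 0\), the right-hand side is 0 and the inequality is trivial). Since \(\log\) is increasing, we can take the logarithm of both sides and rewrite the above with \(\lambda_1=\lambda_2=\cdots=\lambda_n=\frac{1}{n}\): \[
\log\left(\sum_{i=1}^n \lambda_i x_i\right) \ge \sum_{i=1}^n \lambda_i \log x_i
\]
The above inequality holds due to Jensen’s inequality: \(\log\) is concave, so \(-\log\) is convex, and applying Jensen’s inequality to \(-\log\) flips the direction of the inequality. Note that the same proof works for the weighted version since the proof does not rely on the fact that \(\lambda_i=\frac{1}{n}\) for all \(i=1,2,\cdots,n\).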
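The AM-GM inequality and its weighted version can also be checked numerically. The sketch below assumes strictly positive inputs, because it computes the geometric mean through logarithms. The helper names `weighted_am` and `weighted_gm` are hypothetical.

```python
import math


def weighted_am(xs, lams):
    """Weighted arithmetic mean: sum(lam_i * x_i)."""
    return sum(lam * x for lam, x in zip(lams, xs))


def weighted_gm(xs, lams):
    """Weighted geometric mean: exp(sum(lam_i * log(x_i))), for positive x_i."""
    return math.exp(sum(lam * math.log(x) for lam, x in zip(lams, xs)))


xs = [1.0, 4.0, 16.0]

# Classic AM-GM: equal weights 1/n.
uniform = [1 / 3] * 3
assert weighted_am(xs, uniform) >= weighted_gm(xs, uniform)

# Weighted version: any non-negative weights that sum to 1.
lams = [0.5, 0.25, 0.25]
assert weighted_am(xs, lams) >= weighted_gm(xs, lams)

print(f"uniform:  AM={weighted_am(xs, uniform):.4f}  GM={weighted_gm(xs, uniform):.4f}")
print(f"weighted: AM={weighted_am(xs, lams):.4f}  GM={weighted_gm(xs, lams):.4f}")
```

The two means are equal only when all \(x_i\) with non-zero weight are equal, which matches the picture: A and B coincide only when the polygon collapses to a single point on the graph.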
The End
I hope you enjoyed this post. You can ask further questions on my Telegram channel.