Jensen’s inequality is fundamental in many fields, including machine learning and statistics. For example, it is useful for understanding the variational lower bound in the diffusion models paper. In this post, I will give a simple geometric intuition for Jensen’s inequality.
A function is convex when the line segment joining any two points on its graph lies above or on the graph. In the simplest terms, a convex function is shaped like \(\cup\) and a concave function is shaped like \(\cap\). If \(f\) is convex, then \(-f\) is concave.
An interactive visualization of the convex function: \(f(x)=0.15(x - 15)^2 + 15\). We will use the same parabola throughout this post unless stated otherwise. You can use the slider to try different values of \((\lambda_1, \lambda_2)\), where \(\lambda_2=1-\lambda_1\).
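If you prefer to poke at the parabola numerically instead of with the slider, here is a minimal Python sketch of it (the sample points are arbitrary choices for illustration):

```python
def f(x):
    """The convex parabola used throughout this post."""
    return 0.15 * (x - 15) ** 2 + 15

# The vertex sits at (15, 15); values grow symmetrically as we move away from it,
# which is the "shaped like a cup" picture of convexity.
for x in [5, 10, 15, 20, 25]:
    print(f"f({x}) = {f(x):.2f}")
```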
We have a line segment that joins \((x_1, f(x_1))\) and \((x_2, f(x_2))\). We can sample any point along the line segment with \((\lambda_1 x_1 + \lambda_2 x_2, \lambda_1 f(x_1) + \lambda_2 f(x_2))\). For example:
When \(\lambda_1=1\), we get the first point
When \(\lambda_1=0\), we get the second point
And when \(\lambda_1=0.5\), we get the midpoint of the line segment
and so on… Try the slider above!
This point is shown as the black dot in the visualization above. Let’s call it \(A\).
The point where the function graph intersects the dotted line segment is described by: \((\lambda_1 x_1 + \lambda_2 x_2, f(\lambda_1 x_1 + \lambda_2 x_2))\). Let’s call it \(B\).
The convexity definition above is then simply asserting that \(B_y \le A_y\) (note that \(A_x = B_x\) by construction). We are only showing a single line segment here, but the statement must hold for every such line segment.
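Here is a quick numerical check of this claim, a minimal sketch using the parabola above with two arbitrary endpoints (any choice of \(x_1\) and \(x_2\) would do), sweeping \(\lambda_1\) over \([0, 1]\):

```python
def f(x):
    return 0.15 * (x - 15) ** 2 + 15

x1, x2 = 5.0, 20.0  # arbitrary endpoints, chosen only for illustration

for i in range(11):
    lam1 = i / 10
    lam2 = 1 - lam1
    A_y = lam1 * f(x1) + lam2 * f(x2)   # point A: on the chord
    B_y = f(lam1 * x1 + lam2 * x2)      # point B: on the graph, same x-coordinate
    assert B_y <= A_y + 1e-12           # convexity: B never rises above A
    print(f"lambda1={lam1:.1f}  A_y={A_y:.2f}  B_y={B_y:.2f}")
```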
Jensen’s Inequality
Jensen’s inequality generalizes the convexity definition above to more than two points.
Definition
Assume we have a convex function \(f\), points \(x_1, x_2, \cdots, x_n\) in \(f\)’s domain, and positive weights \(\lambda_1, \lambda_2, \cdots, \lambda_n\) where \(\sum_{i=1}^n \lambda_i = 1\). Then Jensen’s inequality can be stated as: \[
f\!\left(\sum_{i=1}^n \lambda_i x_i\right) \le \sum_{i=1}^n \lambda_i f(x_i)
\]
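Below is a small numerical sanity check of this statement; the use of NumPy, the number of points, and the sampled values are all arbitrary illustrative choices:

```python
import numpy as np

def f(x):
    return 0.15 * (x - 15) ** 2 + 15

rng = np.random.default_rng(0)
n = 5
x = rng.uniform(0, 30, size=n)   # arbitrary points in f's domain
lam = rng.random(n)
lam /= lam.sum()                 # positive weights that sum to 1

lhs = f(np.dot(lam, x))          # f(sum_i lambda_i x_i)
rhs = np.dot(lam, f(x))          # sum_i lambda_i f(x_i)
print(lhs, "<=", rhs, lhs <= rhs)
```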
As before, you can use the slider to try different values of \((\lambda_1, \lambda_2, \lambda_3)\) where \(\lambda_1+\lambda_2+\lambda_3=1\).
We have a triangle that connects the points: \((x_1, f(x_1)), (x_2, f(x_2)), (x_3, f(x_3))\).
In the \(n=2\) case, we used \(\lambda_1\) and \(\lambda_2\) to sample a point along the line segment. Here it is similar, but we can sample any point inside or on the boundary of the triangle with: \[
\left(\lambda_1 x_1 + \lambda_2 x_2 + \lambda_3 x_3,\ \lambda_1 f(x_1) + \lambda_2 f(x_2) + \lambda_3 f(x_3)\right)
\]
When \(\lambda_i=1\) for some \(i \in \{1, 2, 3\}\), we get the point \((x_i, f(x_i))\)
When \(\lambda_1=\lambda_2=\lambda_3=\frac{1}{3}\), we get the center of mass of the triangle
The black point in the visualization shows this point. Let’s call it \(A\).
Note that \((\lambda_1, \lambda_2, \lambda_3)\) are the barycentric coordinates of this point with respect to the triangle. You don’t need to know that for this post; I’m just mentioning it in case you’re already familiar with it.
The point where the parabola meets the dotted line segment is described by: \[
\left(\lambda_1 x_1 + \lambda_2 x_2 + \lambda_3 x_3,\ f(\lambda_1 x_1 + \lambda_2 x_2 + \lambda_3 x_3)\right)
\] As in the \(n=2\) case, let’s call it \(B\); Jensen’s inequality asserts that \(B_y \le A_y\).
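As a concrete instance of the \(n=3\) picture, here is a minimal sketch computing \(A\) and \(B\) at the triangle’s center of mass; the three x-values are arbitrary illustrative choices:

```python
def f(x):
    return 0.15 * (x - 15) ** 2 + 15

xs = [5.0, 15.0, 28.0]        # x-values of the triangle's vertices (arbitrary)
lams = [1 / 3, 1 / 3, 1 / 3]  # equal weights: the center of mass

A_x = sum(l * x for l, x in zip(lams, xs))
A_y = sum(l * f(x) for l, x in zip(lams, xs))  # point A: inside the triangle
B_y = f(A_x)                                   # point B: on the parabola, same x-coordinate
print(f"A = ({A_x:.2f}, {A_y:.2f}), B = ({A_x:.2f}, {B_y:.2f}), B_y <= A_y: {B_y <= A_y}")
```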
In the general case, \((\sum_{i=1}^n \lambda_ix_i, \sum_{i=1}^n \lambda_if(x_i))\) describes a point inside or on the boundary of the convex hull of the points \((x_1, f(x_1)), (x_2, f(x_2)), \cdots, (x_n, f(x_n))\). This convex hull always lies on or above the graph, which is exactly what Jensen’s inequality says.
A few closing notes:
The convex hull may be built from any number of points, including the limit \(n \to \infty\)
As \(n \to \infty\), the convex hull approximates the convex function arbitrarily closely on some interval
The convexity definitions for functions and for polygons coincide once we have enough points, i.e., as \(n \to \infty\)
Jensen’s inequality is especially useful in probability theory: since \(\sum_{i=1}^n \lambda_i = 1\), the weights can be read as probabilities, giving \(f(\mathbb{E}[X]) \le \mathbb{E}[f(X)]\); this also covers the continuous form as \(n \to \infty\), as sketched below.
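Here is a sketch of that probabilistic reading, assuming for illustration that \(X\) is normally distributed (any distribution supported on \(f\)’s domain would work):

```python
import numpy as np

def f(x):
    return 0.15 * (x - 15) ** 2 + 15

rng = np.random.default_rng(0)
X = rng.normal(loc=10.0, scale=4.0, size=100_000)  # illustrative choice of distribution

print("f(E[X]) ~", f(X.mean()))   # left-hand side of f(E[X]) <= E[f(X)]
print("E[f(X)] ~", f(X).mean())   # right-hand side; at least as large for convex f
```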
Applications
AM–GM inequality
The arithmetic mean–geometric mean inequality (AM–GM inequality) states that, for non-negative real numbers \(x_1, x_2, \cdots, x_n\): \[
\frac{x_1+x_2+\cdots+x_n}{n} \ge \sqrt[n]{x_1x_2\cdots x_n}
\]
Let’s prove it with Jensen’s inequality by setting \(\lambda_1=\lambda_2=\cdots=\lambda_n=\frac{1}{n}\) and taking logarithms (assume all \(x_i > 0\); if some \(x_i = 0\), the right-hand side is zero and the inequality is immediate). Since \(\ln\) is concave, Jensen’s inequality holds with its direction reversed: \[
\ln\!\left(\frac{x_1+x_2+\cdots+x_n}{n}\right) \ge \frac{\ln x_1 + \ln x_2 + \cdots + \ln x_n}{n} = \ln\sqrt[n]{x_1x_2\cdots x_n}
\]
Exponentiating both sides gives the AM–GM inequality. Note that the same proof works for the weighted version, since it does not rely on the fact that \(\lambda_i=\frac{1}{n}\) for all \(i=1,2,\cdots,n\).
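And a quick numeric check of both the unweighted and the weighted version; the sampled values are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.1, 10.0, size=6)  # positive numbers (arbitrary)
w = rng.random(6)
w /= w.sum()                        # weights that sum to 1

# Unweighted AM-GM: arithmetic mean >= geometric mean
print(x.mean(), ">=", x.prod() ** (1 / len(x)))

# Weighted AM-GM: sum_i w_i x_i >= prod_i x_i ** w_i
print(np.dot(w, x), ">=", np.prod(x ** w))
```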
The End
I hope you enjoyed this post. You can ask further questions on my Telegram channel.