def base_fig3(*, zlim=(0, 500)):
    """Create a 5x5-inch figure holding a single 3D Axes framed on the beta plane.

    The x and y axes are labelled with LaTeX beta_1 / beta_2, have zero label
    and tick padding, and are both limited to the module-level ``beta_range``.
    ``beta_range`` must be a 2-element ``(low, high)`` sequence, because it is
    unpacked into the limits and paired with ``[0, 0]`` below. Two black
    reference lines are drawn along beta_2 = 0 and beta_1 = 0 in the z = 0 plane.

    Args:
        zlim: ``(bottom, top)`` limits for the z axis. Defaults to ``(0, 500)``,
            the value that was previously hard-coded, so existing callers see
            no change.

    Returns:
        ``(fig, ax)``: the new Figure and its 3D Axes.
    """
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')

    # Labels and ticks with zero padding.
    ax.set_xlabel("$\\beta_1$", labelpad=0)
    ax.set_ylabel("$\\beta_2$", labelpad=0)
    ax.set_zlim(*zlim)
    ax.tick_params(axis='x', pad=0)
    ax.tick_params(axis='y', pad=0)
    ax.set_xlim(*beta_range)
    ax.set_ylim(*beta_range)

    # Reference lines for the two coordinate axes; Axes3D.plot puts them at z=0
    # because no z values are given.
    ax.plot(beta_range, [0, 0], color='k')
    ax.plot([0, 0], beta_range, color='k')
    return fig, ax
def plot3d(reg: Reg, t=3):
    """Render the loss surface together with a regularization constraint in 3D.

    Starting from the axes built by ``base_fig3``, the function does the
    following:

    * It samples ``loss`` on a 100x100 grid spanning ``beta_range`` on both
      axes. It draws the result as a translucent surface and as 50 contour
      levels projected onto the z=0 plane, both using the 'coolwarm' colormap
      capped at ``vmax``.
    * It marks ``(cx, cy)`` with a black cross.
    * It marks the point returned by ``argmin_within_constraint(reg, t)`` with
      a red dot.
    * It lays the patch from ``make_reg_shape(reg, t, color="black")`` flat
      at z=0.

    Finally it fixes the camera angle and calls ``plt.show()``. Nothing is
    returned.

    Args:
        reg: regularization descriptor, forwarded unchanged to
            ``argmin_within_constraint`` and ``make_reg_shape``.
        t: size parameter forwarded to the same two helpers (presumably the
            constraint budget; confirm against their definitions).
    """
    _, ax = base_fig3()

    # Loss surface sampled on a square grid over the plotted window.
    axis_samples = np.linspace(*beta_range, 100)
    grid_b1, grid_b2 = np.meshgrid(axis_samples, axis_samples)
    surface_z = loss(grid_b1, grid_b2, cx=cx, cy=cy)
    color_opts = dict(cmap='coolwarm', vmax=vmax)
    ax.plot_surface(grid_b1, grid_b2, surface_z, alpha=0.7, **color_opts)

    # Black cross at (cx, cy), then contour lines projected onto the z=0 floor.
    ax.plot([cx], [cy], color='black', marker='x', markersize=10)
    ax.contour(
        grid_b1, grid_b2, surface_z,
        levels=50, linewidths=.5, zdir='z', offset=0, **color_opts,
    )

    # Red dot at the minimizer restricted to the constraint region.
    best_b1, best_b2 = argmin_within_constraint(reg, t)
    ax.plot([best_b1], [best_b2], color='r', marker='.', markersize=10)

    # Regularization constraint outline: a 2D patch laid flat at z=0.
    constraint_patch = make_reg_shape(reg, t, color="black")
    ax.add_patch(constraint_patch)
    art3d.pathpatch_2d_to_3d(constraint_patch, z=0)

    # Fixed camera angle, then render.
    ax.view_init(elev=39, azim=-106)
    plt.tight_layout()
    plt.show()