# 在matplotlib中绘制最小二乘估计函数的等高线图

``````python
import jax.numpy as jnp
import numpy as np

def make_mse(x, t):
    """Return the least-squares loss of the linear model ``x.dot(w) + b`` against ``t``.

    The returned ``mse(w, b)`` computes half the *sum* of squared residuals,
    ``sum((x.dot(w) + b - t) ** 2) / 2``; despite its name it is not divided
    by the number of samples.

    ``w`` must be something ``x.dot`` accepts (e.g. a scalar weight for 1-D
    ``x``). Passing a whole ``np.meshgrid`` of ``w`` values makes the dot
    product fail with an incompatible-shapes error, so evaluate the loss one
    grid point at a time (e.g. via ``np.vectorize``) to build a contour grid.
    """
    def mse(w, b):
        residual = x.dot(w) + b - t
        # Stay in jax for both the power and the reduction instead of mixing
        # a numpy reduction with a jax array.
        return jnp.sum(jnp.power(residual, 2)) / 2

    return mse
``````

`x` 和 `t` 是不相关的绘图数据，因为 `x` 中的每个值都只与单个值 `w` 各相乘一次。

``````python
x = np.linspace(-1.0,1.0,500)
t = 5 * x + 1

# Candidate parameters: w1 holds slopes w, w2 holds intercepts b.
xcoord = np.linspace(-10.0, 10.0, 50)
ycoord = np.linspace(-10.0, 10.0, 50)
w1, w2 = np.meshgrid(xcoord, ycoord)

# make_mse(x, t) returns a loss for a single (w, b) pair: x.dot(w) cannot take
# a whole 50x50 grid of w values and fails with an incompatible-shapes error.
# np.vectorize evaluates the loss once per grid point instead, so Z gets the
# same (50, 50) shape as w1 and w2 and can be passed straight to contourf.
Z = np.vectorize(make_mse(x, t))(w1, w2)
``````

``````text
/usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in dot(lhs, rhs, precision, preferred_element_type)
634   else:
635     raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
--> 636         lhs.shape, rhs.shape))
637
638

TypeError: Incompatible shapes for dot: got (1000, 1) and (50, 50).
``````

``````python
import numpy as np

def make_mse(x, t):
    """Build the least-squares loss for the 1-D linear model ``y = w * x + b``.

    Parameters
    ----------
    x, t : array_like
        Sample inputs and targets. Both are flattened to 1-D (so a column
        vector works too) and must contain the same number of elements.

    Returns
    -------
    callable
        ``mse(w, b)`` returning half the sum of squared residuals,
        ``sum((w * x + b - t) ** 2) / 2`` (not divided by the sample count).
        ``w`` and ``b`` may be scalars or broadcastable arrays such as the
        outputs of ``np.meshgrid``; the result has their broadcast shape,
        holding one loss value per ``(w, b)`` pair.

    Raises
    ------
    ValueError
        If ``x`` and ``t`` contain different numbers of samples.
    """
    x = np.ravel(np.asarray(x, dtype=float))
    t = np.ravel(np.asarray(t, dtype=float))
    if x.size != t.size:
        raise ValueError(
            f"x and t must have the same number of samples, got {x.size} and {t.size}"
        )

    def mse(w, b):
        w = np.asarray(w, dtype=float)
        b = np.asarray(b, dtype=float)
        # Start from zeros of the (w, b) grid shape so empty data still yields
        # a correctly shaped result.
        start = np.zeros(np.broadcast(w, b).shape)
        # Accumulate one sample at a time. Every (w, b) pair is scored on every
        # sample, while memory stays proportional to the grid size rather than
        # grid size * number of samples.
        return sum(((w * x_i + b - t_i) ** 2 for x_i, t_i in zip(x, t)), start) / 2

    return mse

# pyplot is needed for the contour plot at the end of this snippet.
import matplotlib.pyplot as plt

# Noise-free samples of the line t = 5x + 1, so the best fit is w = 5, b = 1.
x = np.linspace(-1.0, 1.0, 500)
t = 5 * x + 1

# Candidate parameters: w1 holds slopes w, w2 holds intercepts b.
xcoord = np.linspace(-10.0, 10.0, 500)
ycoord = np.linspace(-10.0, 10.0, 500)
w1, w2 = np.meshgrid(xcoord, ycoord)

# make_mse(x, t) must return one loss value per (w, b) pair so that Z matches
# the shape of the w1/w2 grid expected by contourf.
Z = make_mse(x, t)(w1, w2)
plt.contourf(w1, w2, Z)
plt.colorbar(label="loss")
plt.xlabel("w")
plt.ylabel("b")
``````

0条评论