Linear Regression (Formula Derivation + NumPy Implementation)
2022-07-22 07:34:00 【西红柿爱喝水】
Formula Derivation
Quantity | Formula | Dimension |
---|---|---|
Input (matrix form) | $\mathbf X= \begin{bmatrix}-\mathbf {x^{(1)}}^T - \\-\mathbf {x^{(2)}}^T- \\\cdots\\-\mathbf {x^{(i)}}^T-\\\cdots\\-\mathbf {x^{(m)}}^T-\end{bmatrix}$ | $m\times n$ |
Input (one sample) | $\mathbf x^{(i)}=\begin{bmatrix} x_{1}^{(i)} & x_{2}^{(i)} & \cdots & x_{j}^{(i)} & \cdots & x_{n}^{(i)}\end{bmatrix}^T$ | $n\times 1$ |
Labels | $\mathbf y=\begin{bmatrix} y^{(1)} & y^{(2)} & \cdots & y^{(i)} & \cdots & y^{(m)}\end{bmatrix}^T$ | $m\times 1$ |
Parameters | $\mathbf w=\begin{bmatrix}w_{1} & w_{2} & \cdots & w_{j} & \cdots & w_{n}\end{bmatrix}^T$ | $n\times 1$ |
Output | $f_{w,b}(\mathbf x^{(i)}) = \mathbf w^T\mathbf x^{(i)} + b$ | scalar |
Output (matrix form) | $f_{w,b}(\mathbf X) = \mathbf X \mathbf w + b$ ($b$ is added to every row) | $m\times 1$ |
Loss function | $cost^{(i)} = (f_{w,b}(\mathbf x^{(i)}) - y^{(i)})^2$ | scalar |
Cost function (L2-regularized) | $\begin{aligned}J(\mathbf w,b) &= \frac{1}{2m} \sum\limits_{i = 1}^{m} cost^{(i)}+\frac{\lambda}{2m}\sum\limits_{j = 1}^{n} w_{j}^2\\&=\frac{1}{2m}(f_{w,b}(\mathbf X)-\mathbf y)^T(f_{w,b}(\mathbf X)-\mathbf y)+\frac{\lambda}{2m}\mathbf w^T\mathbf w\end{aligned}$ | scalar |
Gradient descent | $\begin{aligned}w_j :&= w_j - \alpha \frac{\partial J(\mathbf{w},b)}{\partial w_j}\\ b :&= b - \alpha \frac{\partial J(\mathbf{w},b)}{\partial b}\\\frac{\partial J(\mathbf{w},b)}{\partial w_j} &= \frac{1}{m} \sum\limits_{i = 1}^{m} (f_{\mathbf{w},b}(\mathbf{x}^{(i)}) - y^{(i)})x_{j}^{(i)} + \frac{\lambda}{m} w_j \\ \frac{\partial J(\mathbf{w},b)}{\partial b} &= \frac{1}{m} \sum\limits_{i = 1}^{m} (f_{\mathbf{w},b}(\mathbf{x}^{(i)}) - y^{(i)}) \end{aligned}$ | scalar |
Gradient descent (matrix form) | $\begin{aligned}\mathbf w:&=\mathbf w-\alpha\frac{\partial J(\mathbf{w},b)}{\partial \mathbf{w}}\\ \frac{\partial J(\mathbf{w},b)}{\partial \mathbf{w}}&=\frac{1}{m}\mathbf X^T(f_{w,b}(\mathbf X) -\mathbf y)+\frac{\lambda}{m} \mathbf w\end{aligned}$ | $n\times 1$ |
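The table gives the matrix-form gradient but does not show how it follows from $J(\mathbf w,b)$. One short route is to take the differential of $J$ and read off the gradients. The derivation uses a residual vector $\mathbf e$ and an $m\times 1$ vector of ones $\mathbf 1$, both introduced here only for this purpose:

```latex
\begin{aligned}
J(\mathbf w,b) &= \frac{1}{2m}\mathbf e^T\mathbf e + \frac{\lambda}{2m}\mathbf w^T\mathbf w,
  \qquad \mathbf e = f_{w,b}(\mathbf X) - \mathbf y = \mathbf X\mathbf w + b\mathbf 1 - \mathbf y\\
dJ &= \frac{1}{m}\mathbf e^T d\mathbf e + \frac{\lambda}{m}\mathbf w^T d\mathbf w
    = \frac{1}{m}\mathbf e^T\left(\mathbf X\,d\mathbf w + \mathbf 1\,db\right) + \frac{\lambda}{m}\mathbf w^T d\mathbf w\\
\frac{\partial J(\mathbf w,b)}{\partial \mathbf w} &= \frac{1}{m}\mathbf X^T\mathbf e + \frac{\lambda}{m}\mathbf w
    = \frac{1}{m}\mathbf X^T\left(f_{w,b}(\mathbf X)-\mathbf y\right) + \frac{\lambda}{m}\mathbf w\\
\frac{\partial J(\mathbf w,b)}{\partial b} &= \frac{1}{m}\mathbf 1^T\mathbf e
    = \frac{1}{m}\sum_{i=1}^{m}\left(f_{w,b}(\mathbf x^{(i)})-y^{(i)}\right)
\end{aligned}
```

The $j$-th entry of $\mathbf X^T\mathbf e$ is $\sum_{i=1}^{m} e^{(i)} x_j^{(i)}$, which recovers the per-component formula in the table. The bias $b$ does not appear in the regularization term, so its gradient has no $\lambda$ part.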
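NumPy Implementation
```python
import copy
import numpy as np
import plotly.graph_objects as go

def zscore_normalize_features(X):
    mu = np.mean(X, axis=0)     # per-feature mean, shape (n,)
    sigma = np.std(X, axis=0)   # per-feature standard deviation, shape (n,)
    X_norm = (X - mu) / sigma   # each column now has mean 0, std 1
    return X_norm
```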
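Z-score normalization rescales each column to mean 0 and standard deviation 1. The sketch below checks that property on a tiny matrix. The helper `zscore_normalize_features_with_stats` is hypothetical and not part of the original code. It is a variant of `zscore_normalize_features` that also returns `mu` and `sigma`, because new inputs have to be scaled with the training statistics before they are passed to the model:

```python
import numpy as np

def zscore_normalize_features_with_stats(X):
    """Hypothetical variant of zscore_normalize_features that also returns mu and sigma."""
    mu = np.mean(X, axis=0)      # per-column mean, shape (n,)
    sigma = np.std(X, axis=0)    # per-column (population) std, shape (n,)
    return (X - mu) / sigma, mu, sigma

X = np.array([[1.0, 10.0],
              [2.0, 20.0],
              [3.0, 60.0]])
X_norm, mu, sigma = zscore_normalize_features_with_stats(X)

print(np.mean(X_norm, axis=0))  # approximately [0, 0]
print(np.std(X_norm, axis=0))   # approximately [1, 1]

# New samples are scaled with the *training* mu and sigma, not with their own statistics
x_new = np.array([[4.0, 30.0]])
x_new_norm = (x_new - mu) / sigma
```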
```python
# f_wb: model output for all m samples
def compute_f_wb(X, w, b):
    f_wb = np.dot(X, w) + b  # (m, n) @ (n, 1) + b -> (m, 1)
    return f_wb
```
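`np.dot(X, w) + b` relies on NumPy broadcasting. `X` has shape `(m, n)` and `w` has shape `(n, 1)`, so the product has shape `(m, 1)`, and the scalar `b` is added to every row. The same rules explain why the labels are reshaped later with `y.reshape(-1, 1)`. Subtracting a 1-D `(m,)` array from an `(m, 1)` array does not raise an error. It silently produces an `(m, m)` array. A small demonstration with made-up shapes:

```python
import numpy as np

m, n = 4, 3
X = np.ones((m, n))
w = np.zeros((n, 1))
b = 2.0

f_wb = np.dot(X, w) + b            # (m, n) @ (n, 1) -> (m, 1), b added to every row
print(f_wb.shape)                  # (4, 1)

y_1d = np.arange(m, dtype=float)   # labels as a 1-D array, shape (m,)
print((f_wb - y_1d).shape)         # (4, 4): silently broadcast to (m, m), not a residual vector
print((f_wb - y_1d.reshape(-1, 1)).shape)  # (4, 1): the intended residual vector
```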
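```python
# J_wb: regularized cost
def compute_cost(X, y, w, b, lambda_, f_wb_function):
    m, n = X.shape
    f_wb = f_wb_function(X, w, b)  # (m, 1)
    # regularization term is lambda / (2m) * w^T w, so (2 * m) must be parenthesized
    j_wb = 1 / (2 * m) * np.dot((f_wb - y).T, f_wb - y) + (lambda_ / (2 * m)) * np.dot(w.T, w)  # (1, 1)
    j_wb = j_wb[0, 0]  # scalar
    return j_wb
```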
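The table gives two equivalent forms of $J(\mathbf w,b)$: a sum over samples and a matrix product. `compute_cost` uses the matrix form. The sketch below computes both forms on random data to confirm that they agree. The functions `cost_loop` and `cost_matrix` and the random inputs are illustrative only:

```python
import numpy as np

def cost_loop(X, y, w, b, lambda_):
    """Summation form: 1/(2m) * sum_i cost^(i) + lambda/(2m) * sum_j w_j^2."""
    m, n = X.shape
    total = 0.0
    for i in range(m):
        f_i = np.dot(X[i], w[:, 0]) + b          # scalar prediction for sample i
        total += (f_i - y[i, 0]) ** 2            # squared error cost^(i)
    reg = sum(w[j, 0] ** 2 for j in range(n))    # sum of squared weights (b excluded)
    return total / (2 * m) + lambda_ / (2 * m) * reg

def cost_matrix(X, y, w, b, lambda_):
    """Matrix form: 1/(2m) * e^T e + lambda/(2m) * w^T w, with e = Xw + b - y."""
    m = X.shape[0]
    e = np.dot(X, w) + b - y                     # (m, 1) residuals
    return (np.dot(e.T, e) / (2 * m) + lambda_ / (2 * m) * np.dot(w.T, w))[0, 0]

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
y = rng.normal(size=(5, 1))
w = rng.normal(size=(3, 1))
print(cost_loop(X, y, w, 0.5, 0.1), cost_matrix(X, y, w, 0.5, 0.1))  # the two values match
```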
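```python
# dj_dw, dj_db: gradients of the regularized cost
def compute_gradient(X, y, w, b, lambda_, f_wb_function):
    m, n = X.shape
    f_wb = f_wb_function(X, w, b)  # (m, 1)
    dj_dw = (1 / m) * np.dot(X.T, (f_wb - y)) + (lambda_ / m) * w  # (n, 1); regularization divides by m, not n
    dj_db = (1 / m) * np.sum(f_wb - y)  # scalar
    return dj_dw, dj_db
```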
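A common way to catch mistakes in a hand-derived gradient, such as a wrong denominator in the regularization term, is to compare it with central finite differences of the cost. The following is a self-contained sketch. `cost` and `grad` mirror the table's matrix formulas, and `numeric_grad` is a hypothetical helper that is not part of the original code:

```python
import numpy as np

def cost(X, y, w, b, lambda_):
    m = X.shape[0]
    e = X @ w + b - y                                   # (m, 1) residuals
    return (e.T @ e)[0, 0] / (2 * m) + lambda_ / (2 * m) * (w.T @ w)[0, 0]

def grad(X, y, w, b, lambda_):
    m = X.shape[0]
    e = X @ w + b - y
    dj_dw = X.T @ e / m + lambda_ / m * w               # (n, 1)
    dj_db = np.sum(e) / m                               # scalar
    return dj_dw, dj_db

def numeric_grad(X, y, w, b, lambda_, eps=1e-6):
    """Central differences (J(theta + eps) - J(theta - eps)) / (2 * eps), one parameter at a time."""
    dj_dw = np.zeros_like(w)
    for j in range(w.shape[0]):
        step = np.zeros_like(w)
        step[j, 0] = eps
        dj_dw[j, 0] = (cost(X, y, w + step, b, lambda_) - cost(X, y, w - step, b, lambda_)) / (2 * eps)
    dj_db = (cost(X, y, w, b + eps, lambda_) - cost(X, y, w, b - eps, lambda_)) / (2 * eps)
    return dj_dw, dj_db

rng = np.random.default_rng(1)
X = rng.normal(size=(6, 3))
y = rng.normal(size=(6, 1))
w = rng.normal(size=(3, 1))
b, lambda_ = 0.3, 0.7

analytic_w, analytic_b = grad(X, y, w, b, lambda_)
numeric_w, numeric_b = numeric_grad(X, y, w, b, lambda_)
print(np.max(np.abs(analytic_w - numeric_w)), abs(analytic_b - numeric_b))  # both close to 0
```

With $m=6$ and $n=3$ here, a gradient that used $\lambda/n$ instead of $\lambda/m$ would fail this comparison.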
```python
# Returns w, b, J_history, w_history
def gradient_descent(X, y, w, b, cost_function, gradient_function, f_wb_function, alpha, num_iters, lambda_):
    J_history = []  # cost after each iteration
    w_history = []  # w after each iteration
    w_temp = copy.deepcopy(w)  # do not modify the caller's array
    b_temp = b
    for i in range(num_iters):
        dj_dw, dj_db = gradient_function(X, y, w_temp, b_temp, lambda_, f_wb_function)
        # simultaneous update: both gradients were computed from the same (w, b)
        w_temp = w_temp - alpha * dj_dw
        b_temp = b_temp - alpha * dj_db
        cost = cost_function(X, y, w_temp, b_temp, lambda_, f_wb_function)
        J_history.append(cost)
        w_history.append(w_temp.copy())
    return w_temp, b_temp, J_history, w_history
```
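One way to sanity-check `gradient_descent` is to run the same update rule on synthetic data with known parameters. The result can then be compared with `np.linalg.lstsq`, NumPy's least-squares solver. The solver is used here only as an independent reference; the original post does not use it. With `lambda_ = 0` and normalized features, both methods should end up at essentially the same `w` and `b`. The function `gd` below is a hypothetical compact rewrite of the same loop, kept self-contained:

```python
import numpy as np

def gd(X, y, w, b, alpha, num_iters, lambda_=0.0):
    """Same update rule as gradient_descent, written compactly."""
    m = X.shape[0]
    J_history = []
    for _ in range(num_iters):
        e = X @ w + b - y                              # residuals, (m, 1)
        dj_dw = X.T @ e / m + lambda_ / m * w          # (n, 1)
        dj_db = np.sum(e) / m                          # scalar
        w = w - alpha * dj_dw
        b = b - alpha * dj_db
        e = X @ w + b - y                              # residuals after the update
        J_history.append((e.T @ e)[0, 0] / (2 * m) + lambda_ / (2 * m) * (w.T @ w)[0, 0])
    return w, b, J_history

rng = np.random.default_rng(2)
X = rng.normal(size=(50, 2))
X = (X - X.mean(axis=0)) / X.std(axis=0)               # z-score normalized, as in the post
true_w = np.array([[2.0], [-1.0]])
y = X @ true_w + 0.5                                   # noiseless targets, true b = 0.5

w, b, J_history = gd(X, y, np.zeros((2, 1)), 0.0, alpha=0.5, num_iters=500)

# Reference: least squares on [X, 1], where the column of ones carries the intercept
A = np.hstack([X, np.ones((50, 1))])
theta, *_ = np.linalg.lstsq(A, y, rcond=None)
print(w.ravel(), b)          # close to [2, -1] and 0.5
print(theta.ravel())         # [w_1, w_2, b] from lstsq
```

A learning curve that decreases at every step, as `J_history` does here, is also a quick check that `alpha` is not too large.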
Sample points
```python
x = np.arange(0, 20, 1)
y = 1 + x + 2 * pow(x, 2) + 3 * pow(x, 3)  # noiseless cubic: y = 1 + x + 2x^2 + 3x^3
fig = go.Figure()
fig.update_layout(width=1000, height=618)
fig.add_trace(
    go.Scatter(
        x=x,
        y=y,
        name="Sample points",
        mode="markers"
    )
)
fig.show()
```
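Features
```python
feature_1 = x            # (m,)
feature_2 = pow(x, 2)    # (m,)
feature_3 = pow(x, 3)    # (m,)
x_ = zscore_normalize_features(np.transpose(np.array([
    feature_1,
    feature_2,
    feature_3
])))                     # (m, 3), normalized column by column
y_ = y.reshape(-1, 1)    # (m, 1)
x_.shape, y_.shape       # ((20, 3), (20, 1))
```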
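The three features are $x$, $x^2$ and $x^3$, so the model is still linear in $\mathbf w$ even though it is cubic in $x$. Their raw scales differ by orders of magnitude: for $x\in[0,19]$, $x^3$ reaches 6859. That spread is the reason they are z-score normalized before gradient descent. The sketch below builds the same matrix with `np.column_stack`, which is an equivalent alternative to the transpose construction, and shows the scales before and after normalization:

```python
import numpy as np

x = np.arange(0, 20, 1)

# Same (m, 3) matrix as np.transpose(np.array([x, x**2, x**3]))
X_raw = np.column_stack([x, x**2, x**3])
print(X_raw.shape)             # (20, 3)
print(X_raw.max(axis=0))       # column maxima 19, 361, 6859

X_norm = (X_raw - X_raw.mean(axis=0)) / X_raw.std(axis=0)
print(X_norm.std(axis=0))      # every column now has standard deviation 1
```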
Gradient Descent
```python
m, n = x_.shape
initial_w = np.zeros((n, 1))
initial_b = 0
iterations = 150
alpha = 0.5
lambda_ = 0  # no regularization
w, b, J_history, w_history = gradient_descent(
    x_, y_, initial_w, initial_b,
    compute_cost, compute_gradient, compute_f_wb,
    alpha, iterations, lambda_
)
fig = go.Figure()
fig.update_layout(width=1000, height=618)
fig.add_trace(
    go.Scatter(
        x=np.arange(1, iterations + 1),
        y=J_history,
        name="Learning curve",
        mode="markers+lines"
    )
)
fig.update_layout(
    xaxis_title="Iteration",
    yaxis_title="J_wb"
)
fig.show()
```
Prediction
```python
y_hat = compute_f_wb(x_, w, b)  # (m, 1)
fig = go.Figure()
fig.update_layout(width=1000, height=618)
fig.add_trace(
    go.Scatter(
        x=x,
        y=y,
        name="y",
        mode="markers"
    )
)
fig.add_trace(
    go.Scatter(
        x=x,
        y=y_hat.ravel(),  # flatten (m, 1) to (m,) for plotting
        name="y_hat"
    )
)
fig.update_layout(
    xaxis_title="x",
    yaxis_title="y"
)
fig.show()
```
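The learned `w` and `b` apply to the normalized features, so they cannot be compared directly with the coefficients used to generate the data. Substituting $(x_j-\mu_j)/\sigma_j$ into the model gives $f=\sum_j \frac{w_j}{\sigma_j}x_j + \left(b-\sum_j \frac{w_j\mu_j}{\sigma_j}\right)$. This identity maps the parameters back to the raw scale. The data were generated without noise, so an exact least-squares fit on these features would recover $b=1$ and coefficients $1, 2, 3$. How close the gradient descent result comes depends on how far it has converged. The sketch below uses a hypothetical helper, `denormalize_params`, and uses `np.linalg.lstsq` as a stand-in for a fully converged fit:

```python
import numpy as np

def denormalize_params(w, b, mu, sigma):
    """Hypothetical helper: convert (w, b) learned on z-scored features to raw-feature scale.

    f = sum_j w_j * (x_j - mu_j) / sigma_j + b
      = sum_j (w_j / sigma_j) * x_j + (b - sum_j w_j * mu_j / sigma_j)
    """
    w_raw = w.ravel() / sigma
    b_raw = b - np.sum(w.ravel() * mu / sigma)
    return w_raw, b_raw

x = np.arange(0, 20, 1)
y = (1 + x + 2 * x**2 + 3 * x**3).reshape(-1, 1).astype(float)
X_raw = np.column_stack([x, x**2, x**3]).astype(float)
mu, sigma = X_raw.mean(axis=0), X_raw.std(axis=0)
X_norm = (X_raw - mu) / sigma

# Exact least-squares fit on the normalized features (stand-in for a fully converged run)
A = np.hstack([X_norm, np.ones((len(x), 1))])
theta, *_ = np.linalg.lstsq(A, y, rcond=None)
w, b = theta[:3], theta[3, 0]

w_raw, b_raw = denormalize_params(w, b, mu, sigma)
print(w_raw, b_raw)   # close to [1, 2, 3] and 1
```

The same helper can be applied to the `w`, `b` returned by `gradient_descent`, provided the `mu` and `sigma` used inside `zscore_normalize_features` are kept.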