简单线性回归

线性回归算法(Liner Regression),线性回归与上一节的KNN不同,KNN主要解决分类问题,而LR主要解决回归问题。本篇介绍简单回归算法(SimpleLinearRegression)。

思想简单,容易实现。
许多强大的非线性模型的基础。
结果具有很好的可解释性。
蕴含机器学习中很多重要的思想。

简单线性回归

样本特征只有一个,称为:简单线性回归。
例如房屋和价格的关系:
avatar
通过每个点,我们需要找到一条直线,最大程度的“拟合”样本特征和样本输出标记之间的关系:
avatar
所以我们希望真值和预测值之间的差距尽量小:
avatar

至此,我们可以得到一类机器学习的基本思路

损失函数(loss)尽可能小
效用函数(utility function)尽可能大
avatar
avatar
在简单线性回归中,应用最小二乘法,求得a,和b的值,最小二乘法推导过程见上一篇博客:
avatar

python实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numpy as np
import matplotlib.pyplot as plt
#设置
x = np.array([1. ,2. ,3. ,4., 5.])
y = np.array([1., 3., 2., 3., 5.])
##绘图
plt.scatter(x,y)
plt.axis([0, 6, 0, 6])
plt.show()
##接下来按照公式求解即可
x_mean = np.mean(x)
y_mean = np.mean(y)
num = 0.0
d = 0.0
for x_i, y_i in zip(x, y):
num += (x_i - x_mean) * (y_i - y_mean)
d += (x_i - x_mean) ** 2
##所解值
a = num / d
b = y_mean -a * x_mean
y_hat = a * x + b
##绘图
plt.scatter(x, y)
plt.plot(x, y_hat, color='r')
plt.axis([0, 6, 0, 6])
plt.show()