为了账号安全,请及时绑定邮箱和手机立即绑定

【机器学习】线性回归——单变量梯度下降的实现(Python版)

【线性回归】

如果要用一句话来解释线性回归是什么的话,那么我的理解是这样子的:**线性回归,是从大量的数据中找出最优的线性(y=ax+b)拟合函数,通过数据确定函数中的未知参数,进而进行后续操作(预测)**回归的概念是从统计学的角度得出的,用抽样数据去预估整体(回归中,是通过数据去确定参数),然后再从确定的函数去预测样本。

【损失函数】

用线性函数去拟合数据,那么问题来了,到底什么样子的函数最能表现样本?对于这个问题,自然而然便引出了损失函数的概念,损失函数是一个用来评价样本数据与目标函数(此处为线性函数)拟合程度的一个指标。我们假设,线性函数模型为:
图片描述
基于此函数模型,我们定义损失函数为:
图片描述
从上式中我们不难看出,损失函数是一个累加和(统计量)用来记录预测值与真实值之间的1/2方差,从方差的概念我们知道,方差越小说明拟合的越好。那么此问题进而演变称为求解损失函数最小值的问题,因为我们要通过样本来确定线性函数的中的参数θ_0和θ_1.

【梯度下降】

梯度下降算法是求解最小值的一种方法,但并不是唯一的方法。梯度下降法的核心思想就是对损失函数求偏导,从随机值(任一初始值)开始,沿着梯度下降的方向对θ_0和θ_1的迭代,最终确定θ_0和θ_1的值,注意,这里要同时迭代θ_0和θ_1(这一点在编程过程中很重要),具体迭代过程如下:图片描述

【Python代码实现】

#此处数据集,采用吴恩达第一次作业的数据集:ex1data1.txt
import numpy as np
import matplotlib.pyplot as plt


#读取数据
def readData(path):
    data = np.loadtxt(path,dtype=float,delimiter=",")
    return data
#损失函数,返回损失函数计算结果 
def costFunction(theta_0,theta_1,x,y,m):
    predictValue = theta_0+theta_1*x
    return sum((predictValue-y)**2)/(2*m)
    
#梯度下降算法
#data:数据
#theta_0、theta_1:参数θ_0、θ_1
#iterations:迭代次数
#alpha:步长(学习率)
def gradientDescent(data,theta_0,theta_1,iterations,alpha):
    eachIterationValue = np.zeros((iterations,1))
    x = data[:,0]
    y = data[:,1]
    m =data.shape[0]
    for i in range(0,iterations):
        hypothesis = theta_0+theta_1*x
        temp_0 = theta_0-alpha*((1/m)*sum(hypothesis-y))
        temp_1 = theta_1-alpha*(1/m)*sum((hypothesis-y)*x)
        theta_0 = temp_0
        theta_1 = temp_1
        costFunction_temp = costFunction(theta_0,theta_1,x,y,m)
        eachIterationValue[i,0] =costFunction_temp
    return theta_0,theta_1,eachIterationValue


if __name__ == "__main__":
   data = readData('ex1data1.txt')
   iterations=1500
   plt.scatter(data[:,0],data[:,1],color='g',s=20)
   theta_0,theta_1,eachIterationValue =gradientDescent(data,0,0,iterations,0.01)
   hypothesis = theta_0+theta_1*data[:,0]
   plt.plot(data[:,0],hypothesis)
   plt.title('Fittingcurve')
   plt.show()
   plt.plot(np.arange(iterations),eachIterationValue)
   plt.title('CostFunction')
   plt.show()

结果如下:
图片描述
图片描述

点击查看更多内容
7人点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消