
Linear Regression: Basic Code

Tags:
Machine Learning
# Use a linear model to fit this data.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

pga = pd.read_csv('pga.csv')  # the dataset is listed at the end of this post
# Standardize both columns; the gradient descent further down only converges
# at these learning rates when the data is normalized.
pga.distance = (pga.distance - pga.distance.mean()) / pga.distance.std()
pga.accuracy = (pga.accuracy - pga.accuracy.mean()) / pga.accuracy.std()

lr = LinearRegression()
lr.fit(pga.distance.values[:, np.newaxis], pga['accuracy'])  # another way is using pga[['distance']]
theta0 = lr.intercept_
theta1 = lr.coef_
print(theta0)
print(theta1)
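
As a quick cross-check (an addition, not from the original post), the same slope and intercept can be recovered with numpy's least-squares polynomial fit:

# Degree-1 polyfit returns [slope, intercept]; it should match lr.coef_ and lr.intercept_
slope, intercept = np.polyfit(pga.distance, pga.accuracy, 1)
print(intercept, slope)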


# Calculate the cost function for each theta1
# (the average accumulated squared error)
def cost(x, y, theta0, theta1):
    J = 0
    for i in range(len(x)):
        squared_error = (x[i] * theta1 + theta0 - y[i]) ** 2  # (h(x_i) - y_i)^2
        J += squared_error
    return J / (2 * len(x))
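
For reference (an addition, not in the original), the same cost can be computed in one vectorized step, which avoids the Python loop and runs much faster on long arrays:

def cost_vectorized(x, y, theta0, theta1):
    # J = 1/(2m) * sum((theta1*x_i + theta0 - y_i)^2), evaluated on the whole array at once
    errors = theta1 * np.asarray(x) + theta0 - np.asarray(y)
    return (errors ** 2).sum() / (2 * len(errors))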

# Hold theta0 fixed and sweep theta1 over a grid to visualize the cost curve
theta0 = 100
theta1s = np.linspace(-3, 2, 197)
costs = []
for theta1 in theta1s:
    costs.append(cost(pga['distance'], pga['accuracy'], theta0, theta1))
plt.plot(theta1s, costs)
plt.show()
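
The grid also gives a cheap read of the best theta1 for this fixed theta0 (a small addition, reusing the theta1s and costs computed above):

# Index of the smallest cost on the grid; theta0 is still fixed at 100
print(theta1s[np.argmin(costs)])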


# Partial derivatives used to adjust theta
def partial_cost_theta0(x, y, theta0, theta1):
    # Our model is the linear function h(x) = theta1*x + theta0, not a sigmoid;
    # for a nonlinear fit we could use a sigmoid instead.
    # Operate on the whole x series at once instead of element by element,
    # then sum and average.
    h = theta1 * x + theta0
    diff = h - y
    partial = diff.sum() / len(diff)
    return partial

partial0 = partial_cost_theta0(pga.distance, pga.accuracy, 1, 1)

def partial_cost_theta1(x, y, theta0, theta1):
    # Same linear model h(x) = theta1*x + theta0; the derivative with respect
    # to theta1 picks up an extra factor of x.
    h = theta1 * x + theta0
    diff = (h - y) * x
    partial = diff.sum() / len(diff)
    return partial

partial1 = partial_cost_theta1(pga.distance, pga.accuracy, 0, 5)
print(partial0)
print(partial1)
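
As a sanity check (an addition, not in the original post), each analytic partial derivative can be compared against a numerical finite-difference estimate of the cost function; the helper numeric_partial below is a made-up name for illustration:

def numeric_partial(x, y, theta0, theta1, wrt='theta1', eps=1e-6):
    # Central-difference approximation of dJ/dtheta around (theta0, theta1)
    if wrt == 'theta0':
        hi, lo = cost(x, y, theta0 + eps, theta1), cost(x, y, theta0 - eps, theta1)
    else:
        hi, lo = cost(x, y, theta0, theta1 + eps), cost(x, y, theta0, theta1 - eps)
    return (hi - lo) / (2 * eps)

# Should be very close to partial1 printed above
print(numeric_partial(pga.distance, pga.accuracy, 0, 5))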


def gradient_descent(x, y, alpha=0.1, theta0=0, theta1=0):  # default parameters
    # Loop: compute the cost, adjust the weights, then check whether the cost
    # has converged or the maximum number of iterations has been reached.
    max_iterations = 1000
    convergence_thres = 0.000001

    c = cost(x, y, theta0, theta1)
    costs = [c]
    cost_pre = c + convergence_thres + 1.0

    counter = 0
    while (np.abs(c - cost_pre) > convergence_thres) and (counter < max_iterations):
        update0 = alpha * partial_cost_theta0(x, y, theta0, theta1)
        update1 = alpha * partial_cost_theta1(x, y, theta0, theta1)

        theta0 -= update0
        theta1 -= update1

        cost_pre = c
        c = cost(x, y, theta0, theta1)
        costs.append(c)
        counter += 1
    return {'theta0': theta0, 'theta1': theta1, 'costs': costs}

print("Theta1 =", gradient_descent(pga.distance, pga.accuracy)['theta1'])
result = gradient_descent(pga.distance, pga.accuracy, alpha=.01)
print(result['theta1'])
costs = result['costs']
plt.scatter(range(len(costs)), costs)
plt.show()
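
To close the loop (an addition, not in the original post), the gradient-descent estimates can be checked against the closed-form fit sklearn produced at the top; after convergence the two should agree to several decimal places. This reuses the result dict from the run above:

print('gradient descent:', result['theta0'], result['theta1'])
print('sklearn         :', lr.intercept_, lr.coef_[0])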

Dataset:
Copy the data below and save it as pga.csv

distance,accuracy
290.3,59.5
302.1,54.7
287.1,62.4
282.7,65.4
299.1,52.8
300.2,51.1
300.9,58.3
279.5,73.9
287.8,67.6
284.7,67.2
296.7,60
283.3,59.4
284,72.2
292,62.1
282.6,66.5
287.9,60.9
279.2,67.3
291.7,64.8
289.9,58.1
289.8,61.7
298.8,56.4
280.8,60.5
294.9,57.5
287.5,61.8
282.7,56
277.7,72.5
270.5,71.7
285.2,66
315.1,55.2
281.9,67.6
293.3,58.2
286,59.9
285.6,58.2
289.9,65.7
277.5,59
293.6,56.8
301.1,65.4
300.8,63.4
287.4,67.3
281.8,72.6
277.4,63.1
279.1,66.5
287.4,66.4
280.9,62.3
287.8,57.2
261.4,69.2
272.6,69.4
291.3,65.3
294.2,52.8
285.5,49
287.9,61.1
282.2,65.6
301.3,58.2
276.2,61.7
281.6,68.1
275.5,61.2
309.7,53.1
287.7,56.4
291.6,56.9
284.1,65
299.6,57.5
282.7,60
271.5,72
292.1,58.2
295,59.4
274.9,69
273.6,68.7
299.9,60.1
279.9,74
289.9,66
283.6,59.8
310.3,52.4
291.7,65.6
284.2,63.2
295,53.5
298.6,55.1
297.4,60.4
299.7,67.7
284.4,69.7
286.4,72.4
285.9,66.9
297.6,54.3
272.5,62
277,66.2
287.6,60.9
280.4,69.4
280,63.7
295.4,52.8
274.4,68.8
286.5,73.1
287.7,65.2
291.5,65.9
279,69.4
299,65.2
290.1,69.1
288.9,67.9
288.8,68.2
283.2,61
293.2,58.4
285.3,67.3
284.1,65.7
281.4,67.7
286.1,61.4
284.9,62.3
284.8,68.1
296,62
282.9,71.8
280.9,67.8
291.2,62
292.8,62.2
291,61.9
285.7,62.4
283.9,62.9
298.4,61.5
285.1,65.3
286.1,60.1
283.1,65.4
289.4,58.3
284.6,70.7
296.6,62.3
295.9,64.9
295.2,62.8
293.9,54.5
275,65.5
286.8,69.5
291.1,64.4
284.8,62.5
283.7,59.5
295.4,66.9
291.8,62.7
274.9,72.3
302.9,61.2
272.1,80.4
274.9,74.9
296.3,59.4
286.2,58.8
294.2,63.3
284.1,66.5
299.2,62.4
275.4,71
273.2,70.9
281.6,65.9
295.7,55.3
287.1,56.8
287.7,66.9
296.7,53.7
282.2,64.2
291.7,65.6
281.6,73.4
311,56.2
278.6,64.7
288,65.7
276.7,72.1
292,62
286.4,69.9
292.7,65.7
294.2,62.9
278.6,59.6
283.1,69.2
284.1,66
278.6,73.6
291.1,60.4
294.6,59.4
274.3,70.5
274,57.1
283.8,62.7
272.7,66.9
303.2,58.3
282,70.4
281.9,61
287,59.9
293.5,63.8
283.6,56.3
296.9,55.3
290.9,58.2
303,58.1
292.8,61.1
281.1,65
293,61.1
284,66.5
279.8,66.7
292.9,65.4
284,66.9
282,64.5
280.6,64
287.7,63.4
287.7,63.4
298.3,59.5
299.6,53.4
291.3,62.5
295.2,61.4
288,62.4
297.8,59.5
286,62.6
285.3,66.2
286.9,63.4
275.1,73.7