博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Coursera机器学习编程作业Python实现(Andrew Ng)—— 1.1 Linear regression with one variable...
阅读量:5943 次
发布时间:2019-06-19

本文共 1793 字,大约阅读时间需要 5 分钟。

1.1 Linear regression with one variable

import numpy as npimport matplotlib.pyplot as pltdata1 = np.loadtxt('ex1data1.txt', delimiter=',')

Plotting the Data

plt.scatter(data1[:,0], data1[:,1], c='red', marker='x')plt.xlabel('Population of City in 10,000s')plt.ylabel('Profit in $10,000s')plt.show()

数据预处理及参数初始化

x0 = np.ones((len(data1),1))x1 = data1[:,0]x1 = x1.reshape([len(x1), 1])X = np.hstack((x0, x1)) y = data1[:,1] y = y.reshape([len(y), 1]) theta = np.zeros((2,1)) iterations = 1500alpha = 0.01

定义假设函数

def h(X, theta):    return np.dot(X, theta)

定义代价函数

def computeCost(X, theta, y):    return 0.5 * np.mean(np.square(h(X, theta) - y))

定义梯度下降函数

def gradientDescent(X, theta, y, iterations, alpha):    Cost = []    Cost.append(computeCost(X, theta, y))    for i in range(iterations):        grad0 = np.mean(h(X, theta) - y)        grad1 = np.mean((h(X, theta) - y) * (X[:,1].reshape([len(X), 1])))        theta[0] = theta[0] - alpha * grad0        theta[1] = theta[1] - alpha * grad1        Cost.append(computeCost(X, theta, y))    return theta, Cost

运行并观察结果

theta_result, Cost_result = gradientDescent(X, theta, y, iterations, alpha) theta_result
array([[-3.63029144],       [ 1.16636235]])
predict1 = np.dot(np.array([1, 3.5]), theta_result) predict1
array([0.45197679])
x_predict = [X[:,1].min(), X[:,1].max()]y_predict = [theta_result[0]+theta_result[1]*(X[:,1].min()), theta_result[0]+theta_result[1]*(X[:,1].max())]plt.plot(x_predict, y_predict, c='blue', label='predict')plt.scatter(data1[:,0], data1[:,1], c='red', marker='x', label='train_data')plt.xlabel('Population of City in 10,000s')plt.ylabel('Profit in $10,000s')plt.legend()plt.show()
plt.plot(Cost_result)plt.xlabel('Iterations')plt.ylabel('Cost')plt.show()

转载于:https://www.cnblogs.com/shouzhenghouchuqi/p/10585924.html

你可能感兴趣的文章
ADO.NET复习——自己编写SqlHelper类
查看>>
库函数strlen源码重现及注意问题
查看>>
《实例化需求》读书笔记
查看>>
常用Java8语法小结
查看>>
ZJOI2019 Day2 游记
查看>>
ccf题库中2015年12月2号消除类游戏
查看>>
WinForm窗体间如何传值
查看>>
Ado.Net 连接数据库
查看>>
java多线程系列1:Sychronized关键字
查看>>
解释性的语言vs编译性语言
查看>>
20155222 2016-2017-2 《Java程序设计》第10周学习总结
查看>>
MapReduce1.x与MapReduce2.x差异
查看>>
PHP array_key_exists() 函数(判断某个数组中是否存在指定的 key)
查看>>
Charpter5 软件测试总结
查看>>
python中@staticmethod、@classmethod和实例方法
查看>>
Java创建数组的三种方法
查看>>
管理计算机内存
查看>>
some requirement checks failed
查看>>
存储管理
查看>>
HDU-2089-不要62
查看>>