- Python機器學習算法: 原理、實現與案例
- 劉碩
- 479字
- 2020-01-20 15:06:23
1.4 算法實現
1.4.1 最小二乘法
首先,我們基于最小二乘法實現線性回歸,代碼如下:

此段代碼十分簡單,簡要說明如下。
- _ols()方法:最小二乘法的實現,即
。
- _preprocess_data_X()方法:對
進行預處理,添加
列并設置為1。
- train()方法:訓練模型,調用_ols()方法估算模型參數
,并保存。
- predict()方法:預測,實現函數
,對
中每個實例進行預測。
1.4.2 梯度下降
接下來,我們基于(批量)梯度下降實現線性回歸,代碼如下:



上述代碼簡要說明如下(詳細內容參看代碼注釋)。
- __init__()方法:構造器(也稱為構造函數),保存用戶傳入的超參數。
- _predict()方法:預測的內部接口,實現函數
。
- _loss()方法:實現損失函數
,計算當前
下的損失,該方法有以下兩個用途。
- ◆ 供早期停止法使用:如果用戶通過超參數tol啟用早期停止法,則調用該方法計算損失。
- ◆ 方便調試:迭代過程中可以每次打印出當前損失,觀察變化的情況。
- _gradient()方法:計算當前梯度
。
- _gradient_descent()方法:實現批量梯度下降算法。
- _preprocess_data_X()方法:對
進行預處理,添加
列并設置為1。
- train()方法:訓練模型。該方法由3部分構成:
- ◆ 對訓練集的X_train進行預處理,添加
列并設置為1。
- ◆ 初始化模型參數
,賦值較小的隨機數。
- ◆ 調用_gradient_descent()方法訓練模型參數
。
- ◆ 對訓練集的X_train進行預處理,添加
- predict()方法:預測。內部調用_predict()方法對
中每個實例進行預測。
推薦閱讀
- ReSharper Essentials
- Moodle Administration Essentials
- JavaScript語言精髓與編程實踐(第3版)
- C語言程序設計(第2版)
- DevOps入門與實踐
- Learning Selenium Testing Tools(Third Edition)
- EPLAN實戰設計
- Functional Kotlin
- Unity 5.x By Example
- PHP與MySQL權威指南
- Mastering Gephi Network Visualization
- Arduino電子設計實戰指南:零基礎篇
- 算法精解:C語言描述
- Mastering ASP.NET Web API
- Java面試一戰到底(基礎卷)