- 現(xiàn)代決策樹模型及其編程實踐:從傳統(tǒng)決策樹到深度決策樹
- 黃智瀕編著
- 1755字
- 2022-08-12 16:11:30
2.2.7 CART回歸決策樹的編程實踐
在CART分類決策樹中使用基尼系數(shù)作為尋找最優(yōu)劃分點的依據(jù),在回歸樹中則采用均方誤差最小化準則作為特征和分割點的選擇方法,下面我們基于這種算法來實現(xiàn)回歸樹的模型。
2.2.7.1 整體流程
首先介紹一下整體流程,如代碼段2.14所示。與CART分類樹類似,主要由四部分組成:數(shù)據(jù)集加載、模型訓(xùn)練、模型預(yù)測和決策樹可視化。
代碼段2.14 CART回歸樹測試主程序(源碼位于Chapter02/test_CartRegressor.py)


1. 數(shù)據(jù)集加載(第15~29行)
數(shù)據(jù)集使用的是表2.11中的“流行歌手喜好度”數(shù)據(jù)集。與2.2.3節(jié)的例子類似,首先在第18~19行利用“with open”語法和csv庫讀取數(shù)據(jù)集文件,并將其轉(zhuǎn)換成list類型。與之不同的是,接下來第20~24行針對數(shù)據(jù)集的第2列執(zhí)行數(shù)據(jù)預(yù)處理,將非數(shù)值型字符串(“性別”一列的屬性值)轉(zhuǎn)化成數(shù)字。最后在第25~29行進行數(shù)據(jù)集的劃分和數(shù)據(jù)類型的轉(zhuǎn)換。
2. 決策樹模型的訓(xùn)練和生成(第31~38行)
在CART回歸樹的創(chuàng)建和訓(xùn)練過程中,與CART分類樹相比,回歸樹使用CartRegressor代替CartClassifier,并且指定了浮點型切分點保留的小數(shù)點后有效數(shù)字位數(shù)。訓(xùn)練過程對外提供的接口與CART分類樹相同,但是其內(nèi)部的訓(xùn)練細節(jié)會有所差異,在下文中會對此做詳細描述。我們先來看一下使用上述數(shù)據(jù)集訓(xùn)練得到的決策樹模型。

該model變量實際上是由一組規(guī)則表示的。以上輸出結(jié)果為決策樹的字典(樹形結(jié)構(gòu))數(shù)據(jù)結(jié)構(gòu)形式,在這棵model樹中,從根節(jié)點到每個葉子節(jié)點的每條路徑都代表一條規(guī)則。為了更清晰地表示規(guī)則,我們可以將以上數(shù)據(jù)結(jié)構(gòu)轉(zhuǎn)換成“if-then”的格式,如下所示:


由此可以看出,CART回歸決策樹與一組“if-then”規(guī)則是等價的。
3. 決策樹模型的使用(第40~45行)
在第42~44行的模型預(yù)測階段,調(diào)用CartRegressor類的成員函數(shù)predict,傳入測試集數(shù)據(jù)X_test,返回numpy.array類型的預(yù)測結(jié)果y_pred,并且打印輸出測試集的真實值y_test和預(yù)測值y_pred。
在第45行的模型評估階段,調(diào)用sklearn的r2_score函數(shù)計算R2指標。R2指標是用于評估回歸問題預(yù)測性能的一種指標,R2指標越大,代表預(yù)測性能越好。在這里調(diào)用r2_score函數(shù)前,需要使用“from sklearn.metrics import r2_score”語句將其引入當(dāng)前環(huán)境。實際執(zhí)行結(jié)果如下:

4. 決策樹可視化(第47~49行)
決策樹可視化階段導(dǎo)入了tree_plotter包的tree_plot函數(shù),關(guān)于tree_plotter包的詳細介紹可以回看2.2.3節(jié)。在tree_plot函數(shù)中傳入訓(xùn)練好的模型,底層借助Matplotlib進行可視化,效果如圖2.19所示。

圖2.19 “流行歌手喜好度”數(shù)據(jù)集生成的CART回歸樹
2.2.7.2 訓(xùn)練和創(chuàng)建過程
下面展開介紹CART回歸樹的訓(xùn)練和創(chuàng)建過程。CartRegressor的創(chuàng)建和訓(xùn)練過程與CartClassifier類似。不同點在于CartRegressor去掉了建立字符串與數(shù)值映射的功能,并且將數(shù)據(jù)預(yù)處理部分移到類外部定制。另外,最重要的區(qū)別在于模型訓(xùn)練時切分點的選取。接下來,我們逐一分析這些不同點在CART回歸樹的訓(xùn)練和創(chuàng)建過程中的表現(xiàn)。
首先介紹一下CartRegressor類的構(gòu)造。從代碼段2.15中可以看到,CartRegressor類的實現(xiàn)依然依賴于torch和numpy。在構(gòu)造函數(shù)__init__中,依然需要提供use_gpu和min_samples_split兩個參數(shù)。與CartClassifier類不同的是,新增加了bit參數(shù),bit用來表示連續(xù)屬性離散化時精確的小數(shù)點位數(shù),默認保留2位。另外,CartRegressor類刪掉了CartClassifier類中用于建立字符串與數(shù)值映射的成員變量,讀者可以參照CartClassifier類對比學(xué)習(xí)。
接下來介紹CART回歸樹與分類樹在訓(xùn)練過程中的不同,如代碼段2.15和2.16所示。在函數(shù)__create_tree中,主要有3處與分類樹不同。第一處在第87~89行,當(dāng)滿足遞歸終止條件“節(jié)點樣本數(shù)小于self.min_samples_split”時,返回的預(yù)測值是該集合中所有目標變量的平均值。第二處在第91~93行,差異集中在__choose_best_point_to_split函數(shù)中,在回歸樹中采用“平方誤差最小”的原則來選擇最優(yōu)切分點,該部分內(nèi)容稍后做重點講解。第三處在第95~105行,使用最優(yōu)屬性和最優(yōu)切分點劃分數(shù)據(jù)集時相較分類樹(處理匹配字符串“<=”和“>”的代碼邏輯)做略微調(diào)整。
代碼段2.15 CART分類樹創(chuàng)建過程(源碼位于Chapter02/CartRegressor.py)

代碼段2.16 創(chuàng)建CART回歸樹的核心代碼(源碼位于Chapter02/CartRegressor.py)


__choose_best_point_to_split函數(shù)如代碼段2.17所示。在第132~155行遍歷所有屬性值時,回歸樹中不再計算基尼不純度和基尼增益,而是針對回歸問題計算損失函數(shù)。其中,第144行和第147行分別計算了使用當(dāng)前切分點劃分的左右子樹的殘差平方和,第149行計算左右子樹的總殘差平方和。最后選出取得最小損失函數(shù)的切分點和屬性索引,作為最優(yōu)切分點和最優(yōu)分裂屬性。對比2.2.3節(jié)計算基尼增益的方法,此處計算最小平方誤差的方法具有異曲同工之妙。
代碼段2.17 選擇最優(yōu)切分點(源碼位于Chapter02/CartRegressor.py)


最后是CART回歸樹的預(yù)測過程和可視化過程。由于CART回歸樹與分類樹的預(yù)測過程和可視化過程幾乎完全相同,在此不做贅述,請讀者參考2.2.3節(jié)CART分類樹的預(yù)測和可視化代碼。
以上即為CART回歸樹針對2.2.6節(jié)的“流行歌手喜好度”數(shù)據(jù)集進行編程實踐的全部過程。