- 現(xiàn)代決策樹(shù)模型及其編程實(shí)踐:從傳統(tǒng)決策樹(shù)到深度決策樹(shù)
- 黃智瀕編著
- 4077字
- 2022-08-12 16:11:28
2.2.3 CART分類決策樹(shù)的編程實(shí)踐
針對(duì)2.2.2節(jié)的天氣與是否打網(wǎng)球的數(shù)據(jù)集(PlayTennis數(shù)據(jù)集),我們利用Python和PyTorch編碼展示CART分類樹(shù)模型的細(xì)節(jié)。
2.2.3.1 整體流程
首先介紹整體流程,如代碼段2.1所示。程序主要由四部分組成:數(shù)據(jù)集加載、模型訓(xùn)練、模型預(yù)測(cè)和決策樹(shù)可視化。
代碼段2.1 CART分類樹(shù)測(cè)試主程序(源碼位于Chapter02/test_CartClassifier.py)


1. 數(shù)據(jù)集加載(第14~23行)
首先,第17行使用“with open as”語(yǔ)法打開(kāi)指定數(shù)據(jù)文件,并獲取文件句柄f。該語(yǔ)法會(huì)在執(zhí)行完畢“with open”作用域內(nèi)的代碼后自動(dòng)關(guān)閉數(shù)據(jù)文件。open函數(shù)的第一個(gè)參數(shù)為數(shù)據(jù)文件的路徑;第二個(gè)參數(shù)為文件打開(kāi)方式,“r”代表以只讀方式打開(kāi);encoding參數(shù)指定讀取文件的編碼格式,“gbk”代表使用GBK編碼。
然后,第18行使用csv庫(kù)讀取數(shù)據(jù)文件內(nèi)容并將其轉(zhuǎn)換成Python內(nèi)置的list類型。使用csv庫(kù)需要在代碼片段開(kāi)頭添加“import csv”語(yǔ)句。讀取數(shù)據(jù)時(shí)使用csv庫(kù)的reader函數(shù),參數(shù)傳入“with open”獲取的文件句柄f。
最后,第19~23行利用Python切片和列表生成式將原始數(shù)據(jù)集分割成屬性名列表feature_names、目標(biāo)變量名y_name、屬性集X、目標(biāo)變量集y。由于本例是解決數(shù)據(jù)集比較小的分類問(wèn)題,因此令訓(xùn)練集(X_train和y_train)和測(cè)試集(X_test和y_test)使用相同的數(shù)據(jù)集(X和y),而通常的做法是將原始數(shù)據(jù)集按8:1:1或6:2:2的比例劃分成訓(xùn)練集、測(cè)試集和驗(yàn)證集。另外,為了便于使用PyTorch進(jìn)行GPU加速,我們將數(shù)據(jù)集從list類型進(jìn)一步轉(zhuǎn)換為numpy類型(PyTorch內(nèi)部集成了numpy與Tensor的快速轉(zhuǎn)換方法)。同樣,使用numpy庫(kù)也需要導(dǎo)入相應(yīng)的包,語(yǔ)句為“import numpy as np”。
2. 決策樹(shù)模型的訓(xùn)練和生成(第25~32行)
CART分類樹(shù)的創(chuàng)建和訓(xùn)練過(guò)程被封裝成CartClassifier類,通過(guò)使用“from cart import CartClassifier”將其導(dǎo)入當(dāng)前環(huán)境。
在第27行創(chuàng)建決策樹(shù)的過(guò)程中,觸發(fā)CartClassifier類的構(gòu)造函數(shù)。在這里,設(shè)置use_gpu參數(shù)為T(mén)rue,代表啟用GPU加速。此外,該構(gòu)造函數(shù)還可以傳入其他參數(shù),后面的內(nèi)容將對(duì)其進(jìn)行展開(kāi)介紹。
在第31行CART分類樹(shù)的訓(xùn)練過(guò)程中,調(diào)用CartClassifier類的train成員函數(shù),傳入訓(xùn)練集數(shù)據(jù)X_train、y_train和feature_names。訓(xùn)練完成后,返回模型數(shù)據(jù)如下:

該model變量實(shí)際上是由一組規(guī)則表示的。以上輸出結(jié)果為決策樹(shù)的字典(樹(shù)形結(jié)構(gòu))數(shù)據(jù)結(jié)構(gòu)形式,在這棵model樹(shù)中,從根節(jié)點(diǎn)到每個(gè)葉子節(jié)點(diǎn)的每條路徑都代表一條規(guī)則。為了更清晰地表示規(guī)則,我們可以將以上數(shù)據(jù)結(jié)構(gòu)轉(zhuǎn)換成“if-then”的格式,如下所示:

事實(shí)上,CART分類決策樹(shù)與一組“if-then”規(guī)則是等價(jià)的。
3. 決策樹(shù)模型的使用(第34~41行)
在第36~38行的模型預(yù)測(cè)階段,調(diào)用CartClassifier類的成員函數(shù)predict,傳入測(cè)試集數(shù)據(jù)X_test,返回numpy.array類型的預(yù)測(cè)結(jié)果y_pred,并且打印輸出測(cè)試集的真實(shí)值y_test和預(yù)測(cè)值y_pred。
在第39~41行的模型評(píng)估階段,首先使用Python的列表生成式生成測(cè)試集中預(yù)測(cè)值與真實(shí)值相等的元素,每種相等的情況用int型變量1表示。然后使用numpy的sum函數(shù)對(duì)上述列表求和,統(tǒng)計(jì)出預(yù)測(cè)正確的計(jì)數(shù)。最后打印預(yù)測(cè)正確的樣本計(jì)數(shù)、總樣本計(jì)數(shù)以及預(yù)測(cè)準(zhǔn)確率。實(shí)際執(zhí)行結(jié)果如下:

4. 決策樹(shù)的可視化(第43~45行)
決策樹(shù)可視化階段使用了tree_plotter包的tree_plot函數(shù),傳入前面訓(xùn)練好的模型model。tree_plotter包是我們使用Matplotlib自定義的決策樹(shù)繪圖包,在隨后的內(nèi)容中我們將詳細(xì)介紹,在此先展示一下可視化的效果,如圖2.9所示。
2.2.3.2 訓(xùn)練和創(chuàng)建過(guò)程
首先介紹用到的構(gòu)造函數(shù),見(jiàn)代碼段2.2。從第8~9行可以看到,CartClassifier類的實(shí)現(xiàn)依賴于torch和numpy。
在CartClassifier類的構(gòu)造函數(shù)__init__中,需要提供use_gpu和min_samples_split兩個(gè)參數(shù)。其中,use_gpu是一個(gè)布爾值,代表該類是否啟用GPU加速,默認(rèn)為False,代表使用CPU;min_samples_split是一個(gè)整型數(shù),代表決策樹(shù)分裂完成后葉子節(jié)點(diǎn)的最少樣本數(shù),默認(rèn)值為1,代表樹(shù)完全分裂。

圖2.9 PlayTennis數(shù)據(jù)集生成的CART分類樹(shù)
代碼段2.2 CART分類樹(shù)的構(gòu)造函數(shù)(源碼位于Chapter02/CartClassifier.py)

另外,構(gòu)造函數(shù)中還維護(hù)一系列類的核心變量。其中,self.tree為存儲(chǔ)樹(shù)模型的核心結(jié)構(gòu),初始情況下為空dict;self.feature_names為存儲(chǔ)數(shù)據(jù)集屬性名的numpy數(shù)組;self.str_map、self.num_map、self.x_use_map、self.y_use_map為字符串與數(shù)值之間的映射器和開(kāi)關(guān),用于將numpy數(shù)組中的字符串類型映射成數(shù)值類型,以兼容PyTorch,與之相關(guān)的函數(shù)接口為_(kāi)_deal_value_map和__get_value,在后文中將逐一介紹。
接下來(lái)介紹CART分類樹(shù)的訓(xùn)練函數(shù)train,如代碼段2.3所示。由于Python函數(shù)中傳遞的numpy變量是引用,為了避免后續(xù)對(duì)數(shù)據(jù)集進(jìn)行分割時(shí)破壞原始數(shù)據(jù)集,首先在第36~37行執(zhí)行numpy.array的copy函數(shù)制作X和y的副本。然后在第41行進(jìn)行數(shù)據(jù)預(yù)處理,將X_copy和y_copy中的字符串通過(guò)__deal_value_map函數(shù)映射成數(shù)值。之后在第44~48行將numpy數(shù)組X_copy和y_copy轉(zhuǎn)換成Tensor數(shù)組,并根據(jù)self.use_gpu的值決定是否啟用GPU加速。其中,torch.from_numpy函數(shù)是PyTorch提供的內(nèi)置函數(shù),負(fù)責(zé)將numpy.array數(shù)組轉(zhuǎn)化成torch.Tensor格式,torch.Tensor.cuda函數(shù)也是PyTorch提供的內(nèi)置函數(shù),負(fù)責(zé)對(duì)當(dāng)前的tensor數(shù)組啟用GPU加速。最后,第51行進(jìn)入創(chuàng)建CART分類樹(shù)的核心函數(shù)__create_tree。
代碼段2.3 CART分類樹(shù)訓(xùn)練過(guò)程(源碼位于Chapter02/CartClassifier.py)

第41行提到了一個(gè)關(guān)鍵的字符串映射函數(shù)__deal_value_map,在這里我們對(duì)它進(jìn)行詳細(xì)介紹。之所以將X_copy和y_copy中的字符串通過(guò)該函數(shù)映射成數(shù)值,是因?yàn)樵趫?zhí)行計(jì)算時(shí),為了實(shí)現(xiàn)GPU加速,numpy.array需要轉(zhuǎn)換成PyTorch的torch.Tensor數(shù)據(jù)結(jié)構(gòu),而torch.Tensor僅支持?jǐn)?shù)值類型。具體實(shí)現(xiàn)過(guò)程見(jiàn)代碼段2.4。
代碼段2.4 建立字符串與數(shù)值的映射的函數(shù)__deal_value_map(源碼位于Chapter02/CartClassifier.py)


代碼段2.4展示了從字符串到數(shù)值建立映射的過(guò)程。首先,在第91~103行處理X,改變self.x_use_map的標(biāo)記,遍歷X中的每個(gè)元素,以元素值為key,以當(dāng)前self.str_map的長(zhǎng)度為value,在self.str_map中建立映射,同時(shí),在self.num_map中建立反向的映射。然后,在第106~117行處理y,同理,在y不為None的情況下,在當(dāng)前self.str_map和self.num_map的基礎(chǔ)上繼續(xù)建立字符串映射。最后,在第120~123行返回映射好的X和y。
代碼段2.5則實(shí)現(xiàn)將數(shù)值映射回字符串的功能。其中分為兩種情況:一種是使用了字符串到數(shù)值的映射的key(通過(guò)to_X、self.x_use_map和self.y_use_map的值可以判別,如第132~133行所示),此時(shí)使用self.num_map映射回字符串;另一種是沒(méi)有使用映射的key(如第134~137行所示),這種情況直接返回key或者key.item(key為T(mén)ensor中的數(shù)值類型時(shí))的值。
代碼段2.5 從數(shù)值到字符串的映射中還原值(源碼位于Chapter02/CartClassifier.py)


接下來(lái),我們回到代碼段2.3。在代碼段2.3的第51行調(diào)用了__create_tree函數(shù),它完成了決策樹(shù)的創(chuàng)建過(guò)程,其具體代碼見(jiàn)代碼段2.6。
代碼段2.6 創(chuàng)建CART分類樹(shù)的核心代碼(源碼位于Chapter02/CartClassifier.py)

在上述代碼段中,__create_tree函數(shù)是一個(gè)遞歸創(chuàng)建決策樹(shù)的過(guò)程。首先,在第147~153行判斷三種遞歸終止條件:X中樣本全部屬于同一類別、當(dāng)前節(jié)點(diǎn)樣本數(shù)小于self.min_samples_split、屬性集上的取值均相同。若滿足終止條件,則調(diào)用__get_value函數(shù)返回從數(shù)值到字符串的映射值,若未滿足終止條件,則繼續(xù)往下計(jì)算。然后,在第155~157行根據(jù)基尼增益從屬性值中選擇最優(yōu)分裂屬性的最優(yōu)切分點(diǎn),具體過(guò)程如__choose_best_point_to_split函數(shù)所示。最后,在第159~169行根據(jù)最優(yōu)切分點(diǎn)對(duì)子樹(shù)進(jìn)行劃分,對(duì)于其子樹(shù)再繼續(xù)執(zhí)行__create_tree函數(shù)完成劃分過(guò)程。
在代碼段2.6的第153行調(diào)用了__majority_y_id函數(shù),它用于計(jì)算節(jié)點(diǎn)中出現(xiàn)次數(shù)最多的類別,具體見(jiàn)代碼段2.7。它首先在第191行進(jìn)行合法性檢查,確保輸入?yún)?shù)y_tensor的元素個(gè)數(shù)大于0。然后在第193~197行初始化一個(gè)空dict,遍歷y_tensor并對(duì)其元素進(jìn)行計(jì)數(shù)。最后在第199~203行從字典y_count中查找出現(xiàn)次數(shù)最多的類別ID(或映射值)。
在代碼段2.6的第156行,調(diào)用了__choose_best_point_to_split函數(shù),它用于選擇最優(yōu)切分點(diǎn),具體見(jiàn)代碼段2.8。
代碼段2.7 計(jì)算節(jié)點(diǎn)中出現(xiàn)次數(shù)最多的類別(源碼位于Chapter02/CartClassifier.py)

代碼段2.8 選擇最優(yōu)切分點(diǎn)(源碼位于Chapter02/CartClassifier.py)


代碼段2.8是CART分類樹(shù)中最核心的函數(shù),該函數(shù)負(fù)責(zé)選擇最優(yōu)切分點(diǎn)。根據(jù)前面的理論推導(dǎo),該函數(shù)的目的是計(jì)算取得最大基尼增益的屬性值。首先在第219行調(diào)用__cal_gini_impurity函數(shù)計(jì)算總數(shù)據(jù)集的基尼不純度G(root)。然后在第220~237行遍歷每個(gè)屬性的每個(gè)屬性值,根據(jù)是否等于屬性值(二分類問(wèn)題)將數(shù)據(jù)集分割到左右子樹(shù),依次計(jì)算左右子樹(shù)的基尼不純度G(left)和G(right),以及左右子樹(shù)中數(shù)據(jù)樣本在總樣本中占的比例P(left)和P(right),并且將G(root)、G(left)、G(right)、P(left)和P(right)代入__cal_gini_gain函數(shù)中計(jì)算基尼增益。最后在第239~243行選出具有最大基尼增益的屬性值,作為當(dāng)前節(jié)點(diǎn)的最優(yōu)切分點(diǎn),并返回最優(yōu)切分點(diǎn)和最優(yōu)分裂屬性索引。
在代碼段2.8的第236行,調(diào)用__cal_gini_gain函數(shù)來(lái)計(jì)算基尼增益。在計(jì)算基尼增益之前,我們需要知道如何計(jì)算一個(gè)數(shù)據(jù)集的基尼不純度。如代碼段2.8的第229行和第232行所示,通過(guò)調(diào)用__cal_gini_impurity函數(shù)來(lái)計(jì)算基尼不純度,它的具體實(shí)現(xiàn)見(jiàn)代碼段2.9。
代碼段2.9 計(jì)算基尼不純度(源碼位于Chapter02/CartClassifier.py)

在代碼段2.9的第253~258行,我們分析導(dǎo)入的數(shù)據(jù)集的最后一列(一般默認(rèn)為數(shù)據(jù)類別),根據(jù)不同類別按出現(xiàn)次數(shù)統(tǒng)計(jì)到分類字典中。在第260~265行遍歷該字典,根據(jù)公式用1減去不同的類分布概率的平方和,得到最終的基尼不純度。接下來(lái)在計(jì)算基尼不純度的基礎(chǔ)上進(jìn)一步實(shí)現(xiàn)基尼增益的計(jì)算,即__cal_gini_gain函數(shù)。它的具體代碼見(jiàn)代碼段2.10。
代碼段2.10 計(jì)算基尼增益(源碼位于Chapter02/CartClassifier.py)

求解基尼指數(shù)的過(guò)程與求解基尼增益的過(guò)程有著相似之處,它們都需要?jiǎng)澐謹(jǐn)?shù)據(jù)求出基尼不純度,以及左右子樹(shù)中類的比例,只不過(guò)基尼指數(shù)不需要求總數(shù)據(jù)集的基尼不純度,而是將pro_left*gini_impurity_left與pro_right*gini_impurity_right累加求和。因此,在選擇最優(yōu)切分點(diǎn)時(shí),我們選擇具有最大基尼增益的屬性值,或者具有最小基尼指數(shù)的屬性值。
2.2.3.3 預(yù)測(cè)過(guò)程
代碼段2.11a和代碼段2.11b演示了CART分類樹(shù)進(jìn)行預(yù)測(cè)時(shí)的整體過(guò)程。在預(yù)測(cè)過(guò)程中,依然首先在代碼段2.11a的第61~69行拷貝數(shù)據(jù)集、處理字符串映射和判斷是否啟用GPU加速,然后在代碼段2.11a的第72行和代碼段2.11b的第178~182行遍歷測(cè)試集X_tensor的每個(gè)樣本,使用__classify函數(shù)分別對(duì)其進(jìn)行預(yù)測(cè),最終返回拼接好的預(yù)測(cè)結(jié)果。從代碼段2.11b的結(jié)構(gòu)中可以看出非常好的并行性,因此在多核CPU機(jī)器上處理大數(shù)據(jù)預(yù)測(cè)時(shí),使用多線程將其并行化可以大大提升預(yù)測(cè)效率,對(duì)此不做贅述。
代碼段2.11a CART分類樹(shù)預(yù)測(cè)過(guò)程(源碼位于Chapter02/CartClassifier.py)

代碼段2.11b CART分類樹(shù)預(yù)測(cè)過(guò)程(源碼位于Chapter02/CartClassifier.py)

在代碼段2.11b的第180行,通過(guò)調(diào)用__classify進(jìn)行預(yù)測(cè)分類,其具體代碼見(jiàn)代碼段2.12。在函數(shù)__classify的參數(shù)中,樹(shù)模型tree是字典結(jié)構(gòu),它的每?jī)蓪哟砹藢?shí)際意義上的一層決策樹(shù)。因此在第291~304行的遞歸遍歷過(guò)程中,每次取出tree的前兩層(根節(jié)點(diǎn)和根節(jié)點(diǎn)的左右孩子節(jié)點(diǎn)),其中根節(jié)點(diǎn)代表屬性,根節(jié)點(diǎn)的左右孩子節(jié)點(diǎn)代表屬性的取值及路由方向。根據(jù)以上特點(diǎn),從根節(jié)點(diǎn)開(kāi)始,遞歸遍歷CART分類樹(shù),最終路由到某個(gè)葉子節(jié)點(diǎn),葉子節(jié)點(diǎn)上的值即為該決策樹(shù)的預(yù)測(cè)結(jié)果。
代碼段2.12 CART分類樹(shù)預(yù)測(cè)的核心代碼(源碼位于Chapter02/CartClassifier.py)

2.2.3.4 可視化過(guò)程
對(duì)于可視化過(guò)程,可以借助Matplotlib庫(kù)來(lái)實(shí)現(xiàn),為此,我們結(jié)合樹(shù)的遍歷特點(diǎn),封裝了一套適用于上述決策樹(shù)的tree_plotter可視化包。
在代碼段2.13a和2.13b中,tree_plot函數(shù)為該包對(duì)外提供的決策樹(shù)繪制接口,其整體算法的思路可分為兩個(gè)步驟:首先繪制自身節(jié)點(diǎn),然后判斷自身節(jié)點(diǎn)類型,若為非葉子節(jié)點(diǎn)則繼續(xù)遞歸創(chuàng)建子樹(shù),若為葉子節(jié)點(diǎn)則直接繪制。關(guān)于更詳細(xì)的實(shí)現(xiàn)細(xì)節(jié),請(qǐng)讀者自行閱讀源碼,由于篇幅原因在此不再展開(kāi)。
代碼段2.13a tree_plotter可視化包(源碼位于Chapter02/tree_plotter.py)

代碼段2.13b tree_plotter可視化包(源碼位于Chapter02/tree_plotter.py)


- 超AI入門(mén)
- 共生:科技與社會(huì)驅(qū)動(dòng)的數(shù)字化未來(lái)
- 智能無(wú)線機(jī)器人:人工智能算法與應(yīng)用
- 人工智能視域下機(jī)器學(xué)習(xí)在教育研究中的應(yīng)用
- 里武林的沉淪囈語(yǔ):AI人工智能游戲概念設(shè)定集
- 中國(guó)人工智能創(chuàng)新鏈產(chǎn)業(yè)鏈技術(shù)專利發(fā)展研究
- 人工智能核心:神經(jīng)網(wǎng)絡(luò)(青少科普版)
- 因果推斷導(dǎo)論
- 深度學(xué)習(xí)視頻理解
- 揭秘大模型:從原理到實(shí)戰(zhàn)
- 增強(qiáng)型分析:人工智能技術(shù)驅(qū)動(dòng)的數(shù)據(jù)分析、業(yè)務(wù)決策與案例實(shí)踐
- 巧用ChatGPT輕松玩轉(zhuǎn)新媒體運(yùn)營(yíng)
- 傳感器技術(shù)及應(yīng)用
- AI加速器架構(gòu)設(shè)計(jì)與實(shí)現(xiàn)
- 人工智能:人臉識(shí)別與搜索