官术网_书友最值得收藏!

3.3 用Python實現橫向聯邦圖像分類

本節我們使用Python從零開始實現一個簡單的橫向聯邦學習模型。具體來說,我們將用橫向聯邦來實現對cifar10圖像數據集的分類,模型使用的是ResNet-18。我們將分別從服務端、客戶端和配置文件三個角度詳細講解設計一個橫向聯邦所需要的基本操作。

需要注意的是,為了方便實現,本章沒有采用網絡通信的方式來模擬客戶端和服務端的通信,而是在本地以循環的方式來模擬。在第10章中,我們將介紹利用Flask-SocketIO模擬客戶端和服務端進行網絡通信的實現。

3.3.1 配置信息

聯邦學習在開發過程中會涉及大量的參數配置,其中比較常用的參數設置包括以下幾個。

? 訓練的客戶端數量:每一輪的迭代,服務端會首先從所有的客戶端中挑選部分客戶端進行本地訓練。每一次迭代只選取部分客戶端參與,并不會影響全局收斂的效果,且能夠提升訓練的效率[200]

? 全局迭代次數:即服務端和客戶端的通信次數。通常會設置一個最大的全局迭代次數,但在訓練過程中,只要模型滿足收斂的條件,那么訓練也可以提前終止。

? 本地模型的迭代次數:即每一個客戶端在進行本地模型訓練時的迭代次數。每一個客戶端的本地模型的迭代次數可以相同,也可以不同。

? 本地訓練相關的算法配置:本地模型進行訓練時的參數設置,如學習率(lr)、訓練樣本大小、使用的優化算法等。

? 模型信息:即當前任務我們使用的模型結構。在本案例中,我們使用ResNet-18圖像分類模型[127]

? 數據信息:聯邦學習訓練的數據。在本案例中,我們將使用cifar10數據集。為了模擬橫向建模,數據集將按樣本維度,切分為多份不重疊的數據,每一份放置在每一個客戶端中作為本地訓練數據。

其他的配置信息,比如可能使用到的加密方案、是否使用差分隱私、模型是否需要檢查點文件(checkpoint)、模型聚合的策略等,都可以根據實際需要自行添加或者修改。我們將上面的信息以json格式記錄在配置文件中以便修改,如下所示。

聯邦學習在模型訓練之前,會將配置信息分別發送到服務端和客戶端中保存,如果配置信息發生改變,也會同時對所有參與方進行同步,以保證各參與方的配置信息一致。

3.3.2 訓練數據集

按照上述配置文件中的type字段信息,獲取數據集。這里我們使用torchvision的datasets模塊內置的cifar10數據集。如果要使用其他數據集,讀者可以自行修改。

3.3.3 服務端

橫向聯邦學習的服務端的主要功能是將被選擇的客戶端上傳的本地模型進行模型聚合。但這里需要特別注意的是,事實上,對于一個功能完善的聯邦學習框架,比如我們將在后面介紹的FATE平臺,服務端的功能要復雜得多,比如服務端需要對各個客戶端節點進行網絡監控、對失敗節點發出重連信號等。本章由于是在本地模擬的,不涉及網絡通信細節和失敗故障等處理,因此不討論這些功能細節,僅涉及模型聚合功能。

下面我們定義一個服務端類Server,類中的主要函數包括以下三種。

? 定義構造函數。在構造函數中,服務端的工作包括:第一,將配置信息拷貝到服務端中;第二,按照配置中的模型信息獲取模型,這里我們使用torchvision的models模塊內置的ResNet-18模型。torchvision內置了很多常見的模型(鏈接3-5)。模型下載后,令其作為全局初始模型。

? 定義模型聚合函數。前面我們提到服務端的主要功能是進行模型的聚合,因此定義構造函數后,我們需要在類中定義模型聚合函數,通過接收客戶端上傳的模型,使用聚合函數更新全局模型。聚合方案有很多種,本節我們采用經典的FedAvg算法[200]。FedAvg算法通過使用下面的公式來更新全局模型:

其中,Gt表示第t輪聚合之后的全局模型,表示第i個客戶端在第t+1輪本地更新后的模型,Gt+1表示第t+1輪聚合之后的全局模型。算法代碼如下所示。

? 定義模型評估函數。對當前的全局模型,利用評估數據評估當前的全局模型性能。通常情況下,服務端的評估函數主要對當前聚合后的全局模型進行分析,用于判斷當前的模型訓練是需要進行下一輪迭代、還是提前終止,或者模型是否出現發散退化的現象。根據不同的結果,服務端可以采取不同的措施策略。

3.3.4 客戶端

橫向聯邦學習的客戶端主要功能是接收服務端的下發指令和全局模型,并利用本地數據進行局部模型訓練。

與前一節一樣,對于一個功能完善的聯邦學習框架,客戶端的功能也相當復雜,比如需要考慮本地的資源(CPU、內存等)是否滿足訓練需要、當前的網絡中斷、當前的訓練由于受到外界因素影響而中斷等。讀者如果對這些設計細節感興趣,可以查看當前流行的聯邦學習框架源代碼和文檔,比如FATE,獲取更多的細節。

本節我們僅考慮客戶端本地的模型訓練細節。我們首先定義客戶端類Client,類中的主要函數包括以下兩種。

? 定義構造函數。在客戶端構造函數中,客戶端的主要工作包括:首先,將配置信息拷貝到客戶端中;然后,按照配置中的模型信息獲取模型,通常由服務端將模型參數傳遞給客戶端,客戶端將該全局模型覆蓋掉本地模型;最后,配置本地訓練數據,在本案例中,我們通過torchvision的datasets模塊獲取cifar10數據集后按客戶端ID進行切分,不同的客戶端擁有不同的子數據集,相互之間沒有交集。

? 定義模型本地訓練函數。本例是一個圖像分類的例子,因此,我們使用交叉熵作為本地模型的損失函數,利用梯度下降來求解并更新參數值,實現細節如下面代碼塊所示。

3.3.5 整合

當配置文件、服務端類和客戶端類都定義完畢后,我們將這些信息組合起來。首先,讀取配置文件信息。

接下來,我們將分別定義一個服務端對象和多個客戶端對象,用來模擬橫向聯邦訓練場景。

每一輪的迭代,服務端會從當前的客戶端集合中隨機挑選一部分參與本輪迭代訓練,被選中的客戶端調用本地訓練接口local_train進行本地訓練,最后服務端調用模型聚合函數model_aggregate來更新全局模型,代碼如下所示。

模型聚合完畢后,調用模型評估接口來評估每一輪更新后的全局模型效果。完整的代碼請參見本書配套的GitHub網頁。

主站蜘蛛池模板: 杂多县| 永靖县| 武威市| 宁强县| 清远市| 时尚| 醴陵市| 平乡县| 遂平县| 桐乡市| 临泉县| 都匀市| 宝兴县| 林芝县| 贵德县| 永丰县| 本溪市| 海淀区| 北海市| 务川| 彭阳县| 伊宁县| 石泉县| 内江市| 鹤庆县| 四平市| 海宁市| 石楼县| 衡南县| 孝义市| 潞西市| 桐庐县| 库车县| 湘潭市| 梁山县| 姚安县| 舒兰市| 莱芜市| 理塘县| 天祝| 绍兴县|