- 機器學習系統:設計和實現
- 麥絡 董豪編著
- 1082字
- 2024-12-27 20:30:21
2.4.2 添加C++編寫的自定義算子
算子是構建神經網絡的基礎,也稱為低級API;通過算子的封裝可以實現各類神經網絡層,當開發神經網絡層遇到內置算子無法滿足時,可以通過自定義算子實現。以MindSpore為例,實現一個GPU算子需要如下步驟:
(1)原語(Primitive)注冊:算子原語是構建網絡模型的基礎單元,用戶可以直接或者間接調用算子原語搭建一個神經網絡模型。
(2)GPU算子開發:GPU Kernel用于調用GPU實現加速計算。
(3)GPU算子注冊:算子注冊用于將GPU Kernel及必要信息注冊給框架,由框架完成對GPU Kernel的調用。
1.注冊算子原語
算子原語通常包括算子名、算子輸入、算子屬性[初始化時需要填的參數,如卷積的步長(stride)、填充(padding)]、輸入數據合法性校驗、輸出數據類型推導和維度推導。假設需要編寫加法算子,主要內容如下:
(1)算子名:TensorAdd。
(2)算子屬性:在構造函數__init__中初始化屬性,因加法沒有屬性,因此__init__不需要額外輸入。
(3)算子輸入/輸出及合法性校驗:infer_shape方法中約束兩個輸入維度必須相同,輸出的維度和輸入維度相同。infer_dtype方法中約束兩個輸入數據必須是float32類型,輸出的數據類型和輸入數據類型相同。
(4)算子輸出。
MindSpore中實現注冊TensorAdd,如代碼2.15所示。
代碼2.15 MindSpore實現注冊TensorAdd

在mindspore/ops/operations/math_ops.py文件內注冊加法算子原語后,需要在mindspore/ops/operations/__init__中導出,方便Python導入模塊時候調用,如代碼2.16所示。
代碼2.16 導出注冊算子

2.GPU算子開發
繼承GPU Kernel,實現加法使用類模板定義TensorAddGpuKernel,需要實現以下方法:
(1)Init():用于完成GPU Kernel的初始化,通常包括記錄算子輸入/輸出維度,完成Launch前的準備工作;因此在此記錄Tensor元素個數。
(2)GetInputSizeList():向框架反饋輸入Tensor需要占用的顯存字節數;返回了輸入Tensor需要占用的字節數,TensorAdd有兩個輸入,每個輸入占用字節數為element_num?sizeof(T)。
(3)GetOutputSizeList():向框架反饋輸出Tensor需要占用的顯存字節數;返回了輸出Tensor需要占用的字節數,TensorAdd有一個輸出,占用element_num?sizeof(T)字節。
(4)GetWorkspaceSizeList():向框架反饋工作空間(Workspace)字節數,工作空間是用于計算過程中存放臨時數據的空間;由于TensorAdd不需要工作空間,因此GetWorkspace-SizeList()返回空的std::vector<size_t>。
(5)Launch():通常調用CUDA Kernel(CUDA Kernel是基于Nvidia GPU的并行計算架構開發的核函數),或者cuDNN接口等方式,完成算子在GPU上加速;Launch()接收輸入、輸出在顯存的地址,接著調用TensorAdd完成加速。
GPU算子開發參見代碼2.17。
代碼2.17 GPU算子開發

TensorAdd中調用了CUDA kernelTensorAddKernel來實現element_num個元素的并行相加,如代碼2.18所示。
代碼2.18 實現并行相加

3.GPU算子注冊
GPU算子信息包含:①Primive;②Input dtype,output dtype;③GPU Kernel class;④CUDA內置數據類型。框架會根據Primive和Input dtype、output dtype,調用以CUDA內置數據類型實例化GPU Kernel class模板類。代碼2.19中分別注冊了支持float(浮點型數)和int(整型數)的TensorAdd算子。
代碼2.19 GPU算子注冊

完成上述三步工作后,需要把MindSpore重新編譯,在源碼的根目錄執行bash build.sh-e gpu,最后使用算子進行驗證。
- Mastering Natural Language Processing with Python
- C和C++安全編碼(原書第2版)
- 深入淺出Android Jetpack
- Mastering Yii
- Mastering macOS Programming
- 編程數學
- Getting Started with Laravel 4
- 零基礎學Python網絡爬蟲案例實戰全流程詳解(入門與提高篇)
- 執劍而舞:用代碼創作藝術
- 基于SpringBoot實現:Java分布式中間件開發入門與實戰
- OpenStack Networking Essentials
- Mastering Docker
- Magento 2 Beginners Guide
- Simulation for Data Science with R
- Monitoring Docker