- 機器學習系統:設計和實現
- 麥絡 董豪編著
- 367字
- 2024-12-27 20:30:20
2.3.3 自定義神經網絡層
2.3.1節中使用偽代碼定義機器學習庫中低級API,有了實現的神經網絡基類抽象方法,那么就可以設計更高層次的接口,解決手動管理參數的煩瑣。假設已經有了神經網絡模型抽象方法Cell,構建Conv2D將繼承Cell類,并重構__init__和__call__方法,在__init__方法中初始化訓練參數和輸入參數,在__call__方法中調用低級API實現計算邏輯。使用偽代碼,如代碼2.12所示,通過接口定義描述自定義卷積層的過程。
代碼2.12 自定義神經網絡層

有了上述定義,在使用卷積層時,就不需要創建訓練變量了。假設需要對30×30大小的10個通道的輸入使用3×3的卷積核做卷積,卷積后輸出20個通道,調用方式如代碼2.13所示。
代碼2.13 使用卷積層

在執行過程中,初始化Conv2D時,__setattr__方法會判斷屬性,把屬于Cell類的神經網絡層Conv2D記錄到self._cells中,filters屬于參數(parameter),把參數記錄到self._params中。查看神經網絡層參數使用conv.parameters_and_names;查看神經網絡層列表使用conv.cells_and_names;執行操作使用conv(inputs)。
推薦閱讀
- 動手玩轉Scratch3.0編程:人工智能科創教育指南
- Apache Spark 2 for Beginners
- TestNG Beginner's Guide
- Mastering macOS Programming
- Hands-On Enterprise Automation with Python.
- C語言程序設計案例精粹
- Python機器學習算法與實戰
- Java Web開發詳解
- Natural Language Processing with Java and LingPipe Cookbook
- C語言程序設計簡明教程:Qt實戰
- Processing創意編程指南
- C++語言程序設計
- Learning iOS Security
- JQuery風暴:完美用戶體驗
- R語言實戰(第2版)