- 機器學習系統:設計和實現
- 麥絡 董豪編著
- 170字
- 2024-12-27 20:30:17
2.2.3 模型定義
使用MindSpore定義神經網絡模型需要繼承mindspore.nn.Cell,神經網絡的各層需要預先在__init__方法中定義,然后重載__construct__方法實現神經網絡的前向傳播過程。因為輸入大小處理成32×32的圖片,所以需要用Flatten操作將數據壓平為一維向量后給全連接層。全連接層輸入大小為32×32,輸出是預測0~9中的哪個數字,因此輸出大小為10。代碼2.4定義了一個三層全連接網絡模型。
代碼2.4 定義三層全連接網絡模型
