- 機器學習系統:設計和實現
- 麥絡 董豪編著
- 306字
- 2024-12-27 20:30:18
2.2.5 訓練及保存模型
MindSpore提供了回調(Callback)機制,可以在訓練過程中執行自定義邏輯。代碼2.6使用框架提供的ModelCheckpoint函數,ModelCheckpoint函數可以保存網絡模型和參數,以便進行后續的Fine-tuning(微調)操作。
代碼2.6 定義模型保存

通過MindSpore提供的model.train接口可以方便地進行網絡的訓練,同時使用Loss-Monitor可以監控訓練過程中損失(loss)值的變化,如代碼2.7所示。
代碼2.7 定義模型訓練

其中,dataset_sink_mode用于控制數據是否下沉,數據下沉是指數據通過通道直接傳送到設備(Device)上,可以加快訓練速度,dataset_sink_mode為真(True),表示數據下沉,否則為非下沉。
有了數據集、模型、損失函數、優化器后就可以進行訓練了。代碼2.8把train_epoch設置為1,對數據集進行1次迭代訓練。在train_net方法中,加載了之前下載的訓練數據集,mnist_path是MNIST數據集路徑。
代碼2.8 訓練模型

推薦閱讀