pytorch学习记录
基本操作
导包
1 | import torch |
下载及导入数据
1 | # Download training data from open datasets. |
分批量读取数据
1 | batch_size = 64 |
定义神经网络模型
1 | device=torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" |
损失函数和优化算法定义
1 | loss_fn=nn.CrossEntropyLoss()#因为这里是分类问题,所以选择交叉熵损失 |
定义训练函数
1 | #训练和测试都是传入dataloader,model,loss_function和optimizer |
定义测试函数
1 | #与上述类似 |
正式训练
1 | epochs = 5 |
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 ようこそ、わが楽園へ!!
