分类

分类

没有名字

2025-09-24 发布12 浏览 · 0 点赞 · 0 收藏

模型类构建
定义 Lam 神经网络类,继承 nn.Module,在 init 方法中通过 nn.Sequential 搭建卷积神经网络结构,包含多层 Conv2d 卷积、MaxPool2d 池化、Flatten 展平及 Linear 全连接层,实现从 3 通道图像输入到 10 分类输出的特征提取与映射;编写 forward 方法,完成正向传播逻辑,保障模型可接收输入并输出预测结果 。同时,编写 if name == 'main': 测试代码,用 torch.ones 生成模拟输入,验证模型前向计算时输出维度的合理性,确保模型结构无基础错误 。

模型训练脚本开发
基于 Lam 模型,编写训练流程代码。加载 CIFAR10 数据集,用 torchvision.datasets.CIFAR10 获取训练集与测试集,通过 DataLoader 实现数据批量加载;配置损失函数为 CrossEntropyLoss、优化器为 SGD;设置训练轮次 epoch = 50 等参数,编写训练循环,在训练阶段开启模型 train 模式,逐批处理数据,计算损失、反向传播更新参数,定期输出训练日志;测试阶段切换模型 eval 模式,计算测试集损失与准确率,完整实现模型训练闭环 。

编写推理代码,加载训练好的模型(lam_49_gpu.pth ),处理输入图像(读取、转换通道、Resize、转 Tensor 等),通过 model.eval() 与 torch.no_grad() 保障推理阶段的稳定性,输出预测类别,验证模型实际应用时的推理能力 。


当训练至第100轮时,准确率和Loss均提升

请前往 登录/注册 即可发表您的看法…