计算化学公社

 找回密码 Forget password
 注册 Register
Views: 1005|回复 Reply: 8

[Python] Linux下利用PyTorch对回归与分类任务进行深度学习的脚本

[复制链接 Copy URL]

1061

帖子

16

威望

5791

eV
积分
7172

Level 6 (一方通行)

計算化学の社畜

发表于 Post on 2022-12-15 20:09:16 | 显示全部楼层 Show all |阅读模式 Reading model
本帖最后由 冰释之川 于 2022-12-30 18:04 编辑

相关机器学习的帖子:

Linux下Scikit-learn机器学习实例脚本》(http://bbs.keinsci.com/thread-30547-1-1.html
Linux下利用SHAP对机器学习模型进行合理解释》(http://bbs.keinsci.com/thread-30780-1-1.html

本文简单介绍了如何利用PyTorch进行神经网络建模与训练,这里感谢mizu-bai同学提供的PIP-NN训练教程(https://github.com/mizu-bai/PIP-NN-PyTorch-Tutorial)

本脚本将涉及到如下几个方面的代码编写与应用(以回归任务为例):
(1)神经网络在PyTorch中的定义(涉及到python中的class概念)
(2)将初始的Dataframe架构的数据转换成DataLoader可识别的格式
(3)数据的缩放(利用scikit-learn包辅助实现)
(4)神经网络训练过程与验证过程
(5)交叉验证策略的实现形式(数据集划分部分利用scikit-learn包辅助实现)
(6)最佳模型的保存与加载
(7)回归模型的常用评价指标(利用scikit-learn包辅助实现)



一、脚本依赖包清单
202212151931387990..png


二、在PyTorch中定义神经网络
202212151926112154..png
__init__函数初始化神经网络,其中num_features代表数据集里输入变量X的维度(总特征数量)
forward()函数用于调用__init__中的layer_stack对象,输出预测的y值


三、在PyTorch中进行数据集格式转换
202212151929304841..png
在这里我提供的数据集的初始格式为DataFrame和Series,利用该class对数据格式进行转换,使之适应PyTorch中的DataLoader工具

四、数据缩放
202212151933079249..png
数据处理是整个深度学习的一个重要的组成部分,这里简单示范一下如何对所有的X特征量进行缩放。
缩放的时候要注意的一点是:对scaler进行拟合的时候一定一定要用训练集数据(不要用训练+测试整个数据集),以防数据泄露,使得最终建立的模型表现过于乐观。
然后利用transform()方法分别对训练集与测试集进行缩放


五、数据分割
202212151938391951..png


六、神经网络训练过程
202212151945523399..png
202212151945028702..png
这里预先定义了损失函数,优化器算法以及利用ReduceLROnPlateau对学习率进行自动调控
红框子里的5行代码是核心代码,分别代表(1)计算预测值; (2)计算损失函数;(3)对优化器上一轮保留的梯度信息进行清零;(4)计算梯度并进行反向传播; (5)根据梯度来更新权重参数

七、神经网络验证过程
202212151951319486..png
验证过程相对比较简单,利用model.eval()切换到验证模式,然后为了节约计算时间,关闭验证过程中的梯度计算功能,最后得到验证集下的损失函数值

八、利用训练集对模型进行K折交叉验证
202212151957195119..png
交叉验证是充分利用训练集进行机器学习一种常用策略,这里借用scikit-learn包里的工具对训练集进行划分,然后对计算的损失函数值进行平均
不了解交叉验证的同学,参看下面的简介后肯定秒懂了
202212152003117470..png

九、利用测试集对最佳的神经网络模型进行泛化能力评估
202212151954014548..png
最后对最佳的模型(所有Epoch中平均loss最低的Epoch对应的模型)进行进行泛化能力评估,主要衡量指标为 MSE, MAE,以及R方

最后附上脚本运行过程的输出信息作为参考:
202212152007437503..png


脚本与数据集下载:
PyTorch.7z (200.98 KB, 下载次数 Times of downloads: 87)

评分 Rate

参与人数
Participants 19
威望 +2 eV +82 收起 理由
Reason
chenbq18 + 5 谢谢
JamesBourbon + 4 好物!
ShengLin + 4 好物!
卡开发发 + 5 233333
sobereva + 2
luwis + 5 谢谢
鬼隐 + 5
Gzh_NJ + 5 谢谢
chands + 5 GJ!
LittlePupil + 5 精品内容
an2000 + 4 好物!
Yjc + 4 精品内容
shalene + 5 好物!
swordshine + 3 赞!
zsu007 + 5 赞!
丁越 + 5 赞!
hebrewsnabla + 3 精品内容
含光君 + 5 精品内容
ChrisZheng + 5 хорошо!

查看全部评分 View all ratings

Stand on the shoulders of giants

64

帖子

0

威望

1411

eV
积分
1475

Level 4 (黑子)

发表于 Post on 2022-12-17 03:35:54 | 显示全部楼层 Show all
游客,本帖隐藏的内容需要积分高于 25 才可浏览,您当前积分为 0

1061

帖子

16

威望

5791

eV
积分
7172

Level 6 (一方通行)

計算化学の社畜

 楼主 Author| 发表于 Post on 2022-12-17 07:40:44 | 显示全部楼层 Show all
本帖最后由 冰释之川 于 2022-12-17 09:41 编辑
luwis 发表于 2022-12-17 03:35
**** 本内容被作者隐藏 ****

脚本里自动判断是否有cuda,有的话直接启用gpu了才对。我一会再检查一下问题所在,感谢反馈bug
Stand on the shoulders of giants

1061

帖子

16

威望

5791

eV
积分
7172

Level 6 (一方通行)

計算化学の社畜

 楼主 Author| 发表于 Post on 2022-12-17 10:20:57 | 显示全部楼层 Show all
luwis 发表于 2022-12-17 03:35
**** 本内容被作者隐藏 ****

bug已修复,请下载最新版
Stand on the shoulders of giants

64

帖子

0

威望

1411

eV
积分
1475

Level 4 (黑子)

发表于 Post on 2022-12-17 11:01:17 | 显示全部楼层 Show all
冰释之川 发表于 2022-12-17 10:20
bug已修复,请下载最新版

测试成功。谢谢冰大佬。
感觉不搞机器学习,就是错过了一个时代。
我的困惑是,我能力最多是跟随应用,做不多啥有价值的东西。
冰大佬,请您说说机器学习的不足主要在哪些方面?如果想搞机器学习的话,我们在哪些方面喝点汤?
大家都来说说呗。谢谢!

48

帖子

0

威望

464

eV
积分
512

Level 4 (黑子)

发表于 Post on 2022-12-18 15:10:01 | 显示全部楼层 Show all
太厉害了,学习了!

4

帖子

0

威望

17

eV
积分
21

Level 1 能力者

发表于 Post on 2023-1-12 14:36:02 | 显示全部楼层 Show all
我是真想好好学习啊

51

帖子

0

威望

416

eV
积分
467

Level 3 能力者

发表于 Post on 2023-1-20 15:10:45 | 显示全部楼层 Show all
有个疑问,学校有matlab最近有沙龙教学用matlab 做机器学习,想问下如果使用matlab的机器学习能否达到一些高通量  筛选之类的计算

1061

帖子

16

威望

5791

eV
积分
7172

Level 6 (一方通行)

計算化学の社畜

 楼主 Author| 发表于 Post on 2023-1-20 17:51:48 | 显示全部楼层 Show all
madhatter 发表于 2023-1-20 15:10
有个疑问,学校有matlab最近有沙龙教学用matlab 做机器学习,想问下如果使用matlab的机器学习能否达到一些 ...

没用过MATLAB,不过只要顺手,都能进行批量计算
Stand on the shoulders of giants

本版积分规则 Credits rule

手机版 Mobile version|北京科音自然科学研究中心 Beijing Kein Research Center for Natural Sciences|京公网安备 11010502035419号|计算化学公社 — 北京科音旗下高水平计算化学交流论坛 ( 京ICP备14038949号-1 )|网站地图

GMT+8, 2023-2-7 02:28 , Processed in 0.442327 second(s), 26 queries .

快速回复 返回顶部 返回列表 Return to list