计算化学公社

 找回密码 Forget password
 注册 Register
Views: 2533|回复 Reply: 8
打印 Print 上一主题 Last thread 下一主题 Next thread

[Python] Linux下利用PyTorch对回归与分类任务进行深度学习的脚本

[复制链接 Copy URL]

1102

帖子

18

威望

6643

eV
积分
8105

Level 6 (一方通行)

計算化学の社畜

本帖最后由 冰释之川 于 2022-12-30 18:04 编辑

相关机器学习的帖子:

Linux下Scikit-learn机器学习实例脚本》(http://bbs.keinsci.com/thread-30547-1-1.html
Linux下利用SHAP对机器学习模型进行合理解释》(http://bbs.keinsci.com/thread-30780-1-1.html

本文简单介绍了如何利用PyTorch进行神经网络建模与训练,这里感谢mizu-bai同学提供的PIP-NN训练教程(https://github.com/mizu-bai/PIP-NN-PyTorch-Tutorial)

本脚本将涉及到如下几个方面的代码编写与应用(以回归任务为例):
(1)神经网络在PyTorch中的定义(涉及到python中的class概念)
(2)将初始的Dataframe架构的数据转换成DataLoader可识别的格式
(3)数据的缩放(利用scikit-learn包辅助实现)
(4)神经网络训练过程与验证过程
(5)交叉验证策略的实现形式(数据集划分部分利用scikit-learn包辅助实现)
(6)最佳模型的保存与加载
(7)回归模型的常用评价指标(利用scikit-learn包辅助实现)



一、脚本依赖包清单



二、在PyTorch中定义神经网络

__init__函数初始化神经网络,其中num_features代表数据集里输入变量X的维度(总特征数量)
forward()函数用于调用__init__中的layer_stack对象,输出预测的y值


三、在PyTorch中进行数据集格式转换

在这里我提供的数据集的初始格式为DataFrame和Series,利用该class对数据格式进行转换,使之适应PyTorch中的DataLoader工具

四、数据缩放

数据处理是整个深度学习的一个重要的组成部分,这里简单示范一下如何对所有的X特征量进行缩放。
缩放的时候要注意的一点是:对scaler进行拟合的时候一定一定要用训练集数据(不要用训练+测试整个数据集),以防数据泄露,使得最终建立的模型表现过于乐观。
然后利用transform()方法分别对训练集与测试集进行缩放


五、数据分割



六、神经网络训练过程


这里预先定义了损失函数,优化器算法以及利用ReduceLROnPlateau对学习率进行自动调控
红框子里的5行代码是核心代码,分别代表(1)计算预测值; (2)计算损失函数;(3)对优化器上一轮保留的梯度信息进行清零;(4)计算梯度并进行反向传播; (5)根据梯度来更新权重参数

七、神经网络验证过程

验证过程相对比较简单,利用model.eval()切换到验证模式,然后为了节约计算时间,关闭验证过程中的梯度计算功能,最后得到验证集下的损失函数值

八、利用训练集对模型进行K折交叉验证

交叉验证是充分利用训练集进行机器学习一种常用策略,这里借用scikit-learn包里的工具对训练集进行划分,然后对计算的损失函数值进行平均
不了解交叉验证的同学,参看下面的简介后肯定秒懂了


九、利用测试集对最佳的神经网络模型进行泛化能力评估

最后对最佳的模型(所有Epoch中平均loss最低的Epoch对应的模型)进行进行泛化能力评估,主要衡量指标为 MSE, MAE,以及R方

最后附上脚本运行过程的输出信息作为参考:



脚本与数据集下载:
PyTorch.7z (200.98 KB, 下载次数 Times of downloads: 144) (2022.12.30 更新)




评分 Rate

参与人数
Participants 20
威望 +2 eV +84 收起 理由
Reason
LiHuaYu + 2 <font style="vertical-align: inh
chenbq18 + 5 谢谢
JamesBourbon + 4 好物!
ShengLin + 4 好物!
卡开发发 + 5 233333
sobereva + 2
luwis + 5 谢谢
鬼隐 + 5
Gzh_NJ + 5 谢谢
chands + 5 GJ!
LittlePupil + 5 精品内容
an2000 + 4 好物!
Yjc + 4 精品内容
shalene + 5 好物!
swordshine + 3 赞!
zsu007 + 5 赞!
丁越 + 5 赞!
hebrewsnabla + 3 精品内容
含光君 + 5 精品内容
ChrisZheng + 5 хорошо!

查看全部评分 View all ratings

Stand on the shoulders of giants

66

帖子

0

威望

1473

eV
积分
1539

Level 5 (御坂)

2#
发表于 Post on 2022-12-17 03:35:54 | 只看该作者 Only view this author
游客,本帖隐藏的内容需要积分高于 25 才可浏览,您当前积分为 0

1102

帖子

18

威望

6643

eV
积分
8105

Level 6 (一方通行)

計算化学の社畜

3#
 楼主 Author| 发表于 Post on 2022-12-17 07:40:44 | 只看该作者 Only view this author
本帖最后由 冰释之川 于 2022-12-17 09:41 编辑
luwis 发表于 2022-12-17 03:35
**** 本内容被作者隐藏 ****

脚本里自动判断是否有cuda,有的话直接启用gpu了才对。我一会再检查一下问题所在,感谢反馈bug
Stand on the shoulders of giants

1102

帖子

18

威望

6643

eV
积分
8105

Level 6 (一方通行)

計算化学の社畜

4#
 楼主 Author| 发表于 Post on 2022-12-17 10:20:57 | 只看该作者 Only view this author
luwis 发表于 2022-12-17 03:35
**** 本内容被作者隐藏 ****

bug已修复,请下载最新版
Stand on the shoulders of giants

66

帖子

0

威望

1473

eV
积分
1539

Level 5 (御坂)

5#
发表于 Post on 2022-12-17 11:01:17 | 只看该作者 Only view this author
冰释之川 发表于 2022-12-17 10:20
bug已修复,请下载最新版

测试成功。谢谢冰大佬。
感觉不搞机器学习,就是错过了一个时代。
我的困惑是,我能力最多是跟随应用,做不多啥有价值的东西。
冰大佬,请您说说机器学习的不足主要在哪些方面?如果想搞机器学习的话,我们在哪些方面喝点汤?
大家都来说说呗。谢谢!

49

帖子

0

威望

476

eV
积分
525

Level 4 (黑子)

6#
发表于 Post on 2022-12-18 15:10:01 | 只看该作者 Only view this author
太厉害了,学习了!

5

帖子

0

威望

23

eV
积分
28

Level 2 能力者

7#
发表于 Post on 2023-1-12 14:36:02 | 只看该作者 Only view this author
我是真想好好学习啊

63

帖子

0

威望

703

eV
积分
766

Level 4 (黑子)

8#
发表于 Post on 2023-1-20 15:10:45 | 只看该作者 Only view this author
有个疑问,学校有matlab最近有沙龙教学用matlab 做机器学习,想问下如果使用matlab的机器学习能否达到一些高通量  筛选之类的计算

1102

帖子

18

威望

6643

eV
积分
8105

Level 6 (一方通行)

計算化学の社畜

9#
 楼主 Author| 发表于 Post on 2023-1-20 17:51:48 | 只看该作者 Only view this author
madhatter 发表于 2023-1-20 15:10
有个疑问,学校有matlab最近有沙龙教学用matlab 做机器学习,想问下如果使用matlab的机器学习能否达到一些 ...

没用过MATLAB,不过只要顺手,都能进行批量计算
Stand on the shoulders of giants

本版积分规则 Credits rule

手机版 Mobile version|北京科音自然科学研究中心 Beijing Kein Research Center for Natural Sciences|京公网安备 11010502035419号|计算化学公社 — 北京科音旗下高水平计算化学交流论坛 ( 京ICP备14038949号-1 )|网站地图

GMT+8, 2024-11-27 12:30 , Processed in 0.198644 second(s), 25 queries , Gzip On.

快速回复 返回顶部 返回列表 Return to list