博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch学习笔记之DataLoaders
阅读量:5288 次
发布时间:2019-06-14

本文共 1699 字,大约阅读时间需要 5 分钟。

A DataLoader wraps a Dataset and provides minibatching, shuffling, multithreading, for you。

1 import torch 2 from torch.autograd import Variable 3 import torch.nn as nn 4 from torch.utils.data import TensorDataset, DataLoader 5  6 # define our whole model as a single Module 7 class TwoLayerNet(nn.Module): 8     # Initializer sets up two children (Modules can contain modules) 9     def _init_(self, D_in, H, D_out):10         super(TwoLayerNet, self)._init_()11         self.linear1 = torch.nn.Linear(D_in, H)12         self.linear2 = torch.nn.Linear(H, D_out)13 14     # Define forward pass using child modules and autograd ops on Variables15     # No need to define backward - autograd will handle it16     def forward(self, x):17         h_relu = self.linear1(x).clamp(min=0)18         y_pred = self.linear2(h_relu)19         return y_pred20 21 N, D_in, H, D_out = 64, 1000, 100, 1022 x = Variable(torch.randn(N, D_in))23 y = Variable(torch.randn(N, D_out))24 25 # When you need to load custom data, just write your own Dataset class26 loader = DataLoader(TensorDataset(x, y), batch_size=8)27 28 model = TwoLayerNet(D_in, H, D_out)29 30 criterion = torch.nn.MSELoss(size_average=False)31 optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)32 for epoch in range(10):33     # Iterate(遍历) over loader to form minibatches34     for x_batch, y_batch in loader:35         # Loader gives Tensors so you need to wrap in Variables36         x_var, y_var = Variable(x), Variable(y)37         y_pred = model(x_var)38         loss = criterion(y_pred, y_var)39 40         optimizer.zero_grad()41         loss.backward()42         optimizer.step()

 

转载于:https://www.cnblogs.com/Joyce-song94/p/7220102.html

你可能感兴趣的文章
【mysql的设计与优化专题(5)】慢查询详解
查看>>
Linux 文件目录管理的指令
查看>>
opencv初学习-椒盐噪声-中值滤波-均值滤波-腐蚀膨胀
查看>>
笔记70 Spring Boot快速入门(八)(重要)
查看>>
LeetCode 160 Intersection of Two Linked Lists
查看>>
瀑布流布局
查看>>
log4j教程 5、示例程序
查看>>
《Effective C#》读书笔记
查看>>
解决linux服务器上matplotlib中文显示乱码问题
查看>>
“新零售”个人理解
查看>>
win键盘映射成mac键盘
查看>>
妙色王因缘经
查看>>
Oracle之sql语句优化
查看>>
使用http-server开启一个本地服务器
查看>>
FineUIMvc随笔(3)不能忘却的回发(__doPostBack)
查看>>
Python【每日一问】04
查看>>
php CI框学习整理
查看>>
使用Netty,我们到底在开发些什么?
查看>>
hihocoder #1456 : Rikka with Lattice(杜教筛)
查看>>
基础数论复习
查看>>