import torch
from torch.utils import data as Data
# Subclass Dataset: must implement __init__, __getitem__ and __len__
class DataAdapter(Data.Dataset):
    """Wrap raw feature/label arrays as a PyTorch ``Dataset``.

    Implements the three methods the ``Dataset`` protocol requires:
    ``__init__``, ``__getitem__`` and ``__len__``. Samples are served
    as ``(feature_row, label)`` pairs of float32 tensors.
    """

    def __init__(self, X, Y):
        super().__init__()
        # Convert the inputs to float32 tensors once, up front.
        self.X = torch.FloatTensor(X)
        self.Y = torch.FloatTensor(Y)

    def __getitem__(self, index):
        # One row of features plus its matching label.
        # NOTE(review): assumes X is 2-D (samples x features) — the
        # `[index, :]` slice fails on a 1-D tensor.
        return self.X[index, :], self.Y[index]

    def __len__(self):
        return len(self.X)
# Configure the batch size and the training-set split ratio
def get_data_loader(batch_size, train_split, X=None, y=None):
    """Randomly split (X, y) into training and validation DataLoaders.

    Args:
        batch_size: Mini-batch size used by both loaders.
        train_split: Fraction in (0, 1] of samples assigned to training.
        X: Feature matrix (array-like, shape [n_samples, n_features]).
        y: Label vector (array-like, length n_samples).

    Returns:
        A ``(train_loader, valid_loader)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: if X or y is not supplied.
    """
    # BUG FIX: the original body referenced module-level `X` and `y`
    # that were never defined (the CSV-reading lines were commented
    # out), so it raised NameError. They are now explicit parameters.
    if X is None or y is None:
        raise ValueError("X and y must be provided")
    dataset = DataAdapter(X, y)
    train_size = int(len(X) * train_split)
    valid_size = len(X) - train_size
    # random_split returns torch.utils.data.Subset views of `dataset`.
    train_dataset, valid_dataset = Data.random_split(dataset, [train_size, valid_size])
    # Build the loaders; shuffle both, single-process loading.
    train_loader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    valid_loader = Data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return train_loader, valid_loader
# Training loop: iterate mini-batches from the DataLoader.
# NOTE(review): `train_loader`, `device`, `model`, `criterion` and
# `optimizer` are defined elsewhere — not visible in this chunk.
for i, data in enumerate(train_loader, 0):
    # Move the current batch to the target device.
    X_batch, y_batch = data[0].to(device), data[1].to(device)
    # BUG FIX: the original called model(inputs) with the undefined
    # name `inputs`; the forward pass must run on the current batch.
    y_pred_batch = model(X_batch)
    loss = criterion(y_pred_batch, y_batch)
    # Standard optimization step: clear grads, backprop, update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Last modified: 2021-08-05 11:14 PM
# (c) Reproduction permitted with proper attribution