import torch
from torch.utils import data as Data

# 继承Dataset类,需要实现__init__,__getitem__,__len__
# Adapts raw feature/label arrays into a torch Dataset.
# A Dataset subclass must implement __init__, __getitem__ and __len__.
class DataAdapter(Data.Dataset):

    def __init__(self, X, Y):
        """Store features X and labels Y as float32 tensors.

        X: 2-D array-like of shape (n_samples, n_features)
        Y: 1-D array-like of shape (n_samples,)
        """
        super(DataAdapter, self).__init__()
        # torch.tensor is the documented constructor; the legacy
        # torch.FloatTensor(...) is ambiguous (an int argument would create
        # an uninitialized tensor of that SIZE instead of wrapping the data).
        self.X = torch.tensor(X, dtype=torch.float32)
        self.Y = torch.tensor(Y, dtype=torch.float32)

    def __getitem__(self, index):
        # Return one (features, label) pair for the given sample index.
        return self.X[index, :], self.Y[index]

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.X)

    # 设置batch_size和训练集的比例
def get_data_loader(batch_size, train_split):
    # 下面几行代码是读入数据集
    # df = pd.read_csv(...)
    # X = df.drop('label', axis=1)
    # y = df['label']
    dataset = DataAdapter(X, y)
    train_size = int(len(X) * train_split)
    valid_size = len(X) - train_size
    
    # 随机切分,训练集和验证集
    # train_dataest 和 valid_dataset 的类别是 torch.utils.data.dataset.Subset
    train_dataset,valid_dataset = Data.random_split(dataset,[train_size,valid_size])  # 随机划分训练集和验证集
    
    # 构造 DataLoader,数据集,batch_size,是否shuffle
    train_loader = Data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers = 0) # 加载DataLoader
    valid_loader = Data.DataLoader(valid_dataset, batch_size = batch_size, shuffle=True, num_workers = 0)
    return train_loader,valid_loader

# Training loop: iterate over mini-batches yielded by the DataLoader.
# NOTE(review): train_loader, device, model, criterion and optimizer are
# assumed to be defined earlier in the full script — confirm before running.
for i, data in enumerate(train_loader, 0):
    # Move the current batch to the target device (CPU/GPU).
    X_batch, y_batch = data[0].to(device), data[1].to(device)
    # BUG FIX: the original called model(inputs) with an undefined name
    # `inputs`; the forward pass must use the current batch.
    y_pred_batch = model(X_batch)
    loss = criterion(y_pred_batch, y_batch)
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # backpropagate the loss
    optimizer.step()       # update the model parameters
# Last modified: 2021-08-05, 11:14 PM
# If you found this article useful, feel free to leave a tip.