bc的Pytorch Cheatsheet
一、创建一个新项目
data & dataset & dataloader
第一步,下载数据集。可能是 .txt 或者 .csv 格式文件,查看数据集中的变量名和格式。
第二步,创建 torch.utils.data.Dataset 类型对象。一共有三个必须实现的成员函数,分别是 _init(), _getitem(), _len_()。
- _init_(self, *args):初始化,将数据 load 到内存里,可能还需要进行一些处理。
- _getitem_(self, idx):idx 是一个 int 值,表示需要索引的位置,该函数应该返回一个 data sample,例如,将分子的 Chem.rdchem.Mol 转为 torch_geometric.data.Data 类型对象,会被Dataloader调用。
第三步,创建 torch.utils.data.Dataloader 类型对象。一般选择传入的成员变量是: - dataset: compulsory, Dataset.
- batch_size: optional, int. default = 1.
- shuffle: optional, int. default = False.
- num_workers: optional, int. default = 0. 含义是加载数据过程中子进程的数量,默认数据加载到主进程当中。
- collate_fn: Callable, optional. 合并一个 mini-batch 的成员函数。
重点关注 collate_fn 的实现:Dataset 的 _getitem_ 返回的不同 sample 的数据可能不是等长的,需要在该函数中进行截断和填补。
model
创建一个 model 的对象,可能需要 load 原始参数。
train/eval
第一步,设置模型的状态,例如,model.train() 或者 model.eval()。
第二步,enumerate(dataloader) 获得 batch 数据。
第三步,训练前的准备工作。dataloader 中的数据一般存在于 CPU 上,需要通过 .to(device) 将数据转移到 GPU 上。以及,在 train 函数里,每次梯度回传前,需要将 optimizer 的梯度清空,即 optimizer.zero_grad()。
第四步,计算 loss 和其他评价的指标,比如 accuracy 等。
第五步,在训练的时候,梯度回传,train_loss.backward() + optimizer.step()。
上手一个项目
项目重点看几个部分,首先要把握项目的主脉络
第一,项目依赖,data、model、configuration等;
第二,项目如何启动,bash文件;
第三,项目的入口文件;
第四,项目的状态管理文件,logs等。
其次是读懂代码的Trick:
项目代码中一般不会解释一个变量的内容是什么,可以打印tensor.size()辅助理解。
二、Pytorch语法 Cheatsheet (TODO)
tensor相关处理
避免内存泄漏的心得
三、辅助工具
wandb
登录与创建项目
1 | import wandb |
打印中间信息
可以打印的内容有 A dict of serializable python objects i.e str, ints, floats, Tensors, dicts, or any of the wandb.data_types.
1 | wandb.log(data={"dict_key":"dict_value"}) |
tracemalloc
五一节别人都在旅游快乐打卡,而我却与CPU内存泄漏恶战了三天。
使用tracemalloc可以查看发生内存泄漏的位置。
1 | import tracemalloc |