Pytorch Lightning快速入门

PyTorch Lightning是一个轻量级的PyTorch深度学习框架,旨在简化和规范深度学习模型的训练过程。它提供了一组模块和接口,使用户能够更容易地组织和训练模型,同时减少样板代码的数量。PyTorch Lightning的设计目标是提高代码的可读性和可维护性,同时保持灵活性。它通过将训练循环的组件拆分为独立的模块(如模型、优化器、调度器等),以及提供默认实现来简化用户代码。这使得用户可以专注于模型的定义和高级训练策略,而不必处理底层训练循环的细节。

Pytorch Lightning的官方项目地址为:Lightning-AI/pytorch-lightning。可以通过pip或者conda进行安装:

1
2
3
pip install lightning

conda install lightning -c conda-forge

Quick Start

在一个基础的深度学习流程中,我们会经历数据准备、模型搭建以及模型训练过程,在Pytorch中,它们分别对应DataLoader、Model以及Optimizer。Pytorch lightning则将这些组件进一步拆分成独立的模块,在保持灵活性的同时还提供可读性和可维护性。接下来我们以一个简单的案例来介绍Pytorch lightning的使用。

数据准备

同样首先我们需要进行数据准备,数据准备方面与Pytorch基本没有区别,只需要准备好对应的Dataset和DataLoader即可。

1
2
3
4
5
6
7
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import utils

# setup data
dataset = MNIST('./data', download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

模型定义

在Pytorch中,模型Model的定义需要继承nn.Module。而在Pytorch lightning中,模型均为lightning.LightningModule的子类。在LightningModule中,提供了许多接口,分别对应模型训练步骤、测试步骤、评估步骤、优化器配置等等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch import optim, nn
import lightning as L
import torch.nn.functional as F

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder

def training_step(self, batch, batch_idx):
# training_step defines the train loop.
# it is independent of forward
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = nn.functional.mse_loss(x_hat, x)
# Logging to TensorBoard (if installed) by default
self.log("train_loss", loss)
return loss

def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

在这里,我们定义了一个简单的网络,包含一个encoder和一个decoder,分别由多层MLP组成,一起组成了一个autoencoder。在LightningModule中,提供许多接口方法,我们需要实现这些方法。例如在这里,我们首先实现了training_stepconfigure_optimizers方法,分别完成了训练步骤的定义和优化器配置。

其中,trainint_step方法参数中的batch和batch_idx,实际上就对应DataLoader每次迭代会提供的batch和idx。

模型训练

完成了模型定义和数据准备之后,就可以进行模型训练了。在Pytorch Lightning中,使用Trainer来完成模型训练的过程。

1
2
3
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

上面我们首先实例化了一个Trainer,其中指定了相关的参数,然后调用fit方法,开始模型的训练。在模型训练之前,Pytorch Lightning会自动检测系统可用的加速设备并进行使用,无需我们显式地调用.to(device)等。如果想要支持多设备训练,也只需要在实例化Trainer的时候指定相关参数即可。

在实际训练过程中,可再现性是非常重要的。深度学习过程中具有非常多的随机性,为了确保每次运行的可重复性,我们需要手动设置随机数种子,并在Trainer中设置determinstic标志。

1
2
3
4
5
6
from lightning.pytorch import Trainer, seed_everything

seed_everything(42, workers=True)
# sets seeds for numpy, torch and python.random.
model = Model()
trainer = Trainer(deterministic=True)

默认情况下,在训练过程中,会保存一个存储了最近训练状态的checkpoint,我们可以将其加载并且得到一个模型Model。

1
2
3
4
5
6
7
8
9
10
11
12
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

模型评估和测试

通常,在深度学习的流程当中,还会有模型评估和测试的参与。这里我们也可以补充这一部分。当然首先需要完成对应数据集的准备。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root='./data', download=True, train=True, transform=transform)
test_set = datasets.MNIST(root='./data', download=True, train=False, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

在LightningModule中也提供了可选的validation_steptest_step方法,我们可以实现这些接口方法来实现模型评估和测试。其中方法参数与training_step类似,都是直接从DataLoader中得到的batch和idx。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# define the LightningModule
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder

def training_step(self, batch, batch_idx):
...

def test_step(self, batch, batch_idx):
# this is the test loop
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
test_loss = nn.functional.mse_loss(x_hat, x)
self.log("test_loss", test_loss)
return test_loss

def validation_step(self, batch, batch_idx):
# this is the validation loop
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
val_loss = nn.functional.mse_loss(x_hat, x)
self.log("val_loss", val_loss)
return val_loss

def configure_optimizers(self):
...

在默认情况下,模型评估是在每一个epoch训练完成之后进行的。

1
2
3
4
5
6
7
8
9
10
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

model = LitAutoEncoder(encoder, decoder)

# train with both splits
trainer = L.Trainer(limit_train_batches=1000, limit_val_batches=800, max_epochs=2)
trainer.fit(model, train_loader, valid_loader)

而模型测试对应是在模型训练完成之后进行。它对应的方法是test,在其中提供测试数据对应的DataLoader。

1
2
3
4
5
from torch.utils.data import DataLoader

trainer = L.Trainer()
model = LitAutoEncoder(encoder, decoder)
trainer.test(model, dataloaders=DataLoader(test_set))

LightningModule

LightningModules是Pytorch Lightning的一大核心类。LightningModule主要在6个方面来组织Pytorch代码,它们分别对应不同接口方法。

  • init初始化:对应__init__setup方法
  • train loop:对应training_step方法
  • validation loop:对应validation_step方法
  • test loop:对应test_step方法
  • prediction loop:对应predict_step方法
  • optimizers and LR schedulers:对应configure_optimizers方法

Training

在LightningModule中我们可以重写forward方法。使用这个方法可以让我们以直接调用的模式调用模型,类似nn.Module,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from torch import optim, nn
import lightning as L
import torch.nn.functional as F

encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder

def forward(self, inputs):
return self.decoder(self.encoder(inputs))

借用forward方法,我们也可以简化training_step的书写。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder

def forward(self, inputs):
return self.decoder(self.encoder(inputs))

def training_step(self, batch, batch_idx):
inputs, targets = batch
inputs = inputs.view(inputs.size(0), -1)
outputs = self(inputs)
loss = nn.functional.mse_loss(outputs, inputs)
self.log("train_loss", loss)
return loss

训练循环对应training_step方法,该方法需要返回一个loss。重写该方法之后,对应执行的逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# put model in train mode and enable gradient calculation
model.train()
torch.set_grad_enabled(True)

for batch_idx, batch in enumerate(train_dataloader):
loss = training_step(batch, batch_idx)

# clear gradients
optimizer.zero_grad()

# backward
loss.backward()

# update parameters
optimizer.step()

如果希望记录epoch级别的metric,则可以调用self.log()方法。该方法会接受两个参数,分别是key和value,其中value会被组织成list的形式,最终返回的是list上的平均作为最终的epoch metric。例如在上面的training_step中我们调用了self.log方法,分别传入train_loss和loss,最后就会得到一个epoch级别的train_loss,它是通过在epoch中记录的所有loss计算平均得到的。类似于下面的逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
outs = []
for batch_idx, batch in enumerate(train_dataloader):
# forward
loss = training_step(batch, batch_idx)
outs.append(loss.detach())

# clear gradients
optimizer.zero_grad()
# backward
loss.backward()
# update parameters
optimizer.step()

# note: in reality, we do this incrementally, instead of keeping all outputs in memory
epoch_metric = torch.mean(torch.stack(outs))

Validation

模型评估对应的是validation_step方法:

1
2
3
4
5
6
7
def validation_step(self, batch, batch_idx):
inputs, targets = batch
inputs = inputs.view(inputs.size(0), -1)
outputs = self(inputs)
loss = nn.functional.mse_loss(outputs, inputs)
self.log("val_loss", loss)
return loss

实际对应伪代码逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# ...
for batch_idx, batch in enumerate(train_dataloader):
loss = model.training_step(batch, batch_idx)
loss.backward()
# ...

if validate_at_some_point:
# disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()

# ----------------- VAL LOOP ---------------
for val_batch_idx, val_batch in enumerate(val_dataloader):
val_out = model.validation_step(val_batch, val_batch_idx)
# ----------------- VAL LOOP ---------------

# enable grads + batchnorm + dropout
torch.set_grad_enabled(True)
model.train()

在Trainer中也提供validate接口调用在val datalodar上的验证逻辑:

1
2
trainer = L.Trainer()
trainer.validate(model)

Testing

模型测试对应的是test_step方法:

1
2
3
4
5
6
7
def test_step(self, batch, batch_idx):
inputs, targets = batch
inputs = inputs.view(inputs.size(0), -1)
outputs = self(inputs)
loss = nn.functional.mse_loss(outputs, inputs)
self.log("test_loss", loss)
return loss

测试循环仅仅在Trainer的test方法中被调用

1
2
3
trainer = L.Trainer()
model = LitAutoEncoder(encoder, decoder)
trainer.test(model, dataloaders=test_dataloader)

Inference

模型推理即表示模型对给定输入如何进行处理并得到输出,对应方法predict_step

1
2
3
4
def predict_step(self, batch):
inputs, targets = batch
inputs = inputs.view(inputs.size(0), -1)
return self(inputs)

对应的伪代码逻辑如下:

1
2
3
4
5
6
7
8
# disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()
all_preds = []

for batch_idx, batch in enumerate(predict_dataloader):
pred = model.predict_step(batch, batch_idx)
all_preds.append(pred)

可以通过Trainer的predict方法直接调用:

1
2
trainer = L.Trainer()
predictions = trainer.predict(autoencoder, dataloaders=train_loader)
  • 默认情况下, predict_step()执行的是forward()方法。如果我们想用forward方法对应的模型直接调用并进行推理,在代码中需要显式指定禁止梯度回传。
1
2
3
4
model.eval()
with torch.no_grad():
batch = dataloader.dataset[0]
pred = model(batch)

Save and Load Model

Pytorch lightning中提供lightning checkpoint来保存对象状态。lightning checkpoint不是仅仅保存了模型当前的参数,而是基本保存了模型完整的内部状态,包括当前epoch、global step、模型的state_dict、optimzer状态、scheduler状态等等。利用lightning checkpoint,我们可以完全无缝地恢复全部运行状态。

Trainer() 会默认在当前目录下保存一个checkpoint,记录上一个训练周期的状态。这点可以帮助我们在训练中断的时候及时恢复训练。checkpoint保存的位置可以通过可以通过default_root_dir参数来修改。如果想要禁用该功能,则可以通过enable_checkpointing=False来控制。

1
2
trainer = Trainer(default_root_dir='your/path/here')
trainer_no_checkpoint = Trainer(enable_checkpointing=False)

在LightningModule中提供self.save_hyperparameters()方法来保存所有在init构造方法中传入的超参数。在保存的checkpoint文件中,该部分参数被存储在hyper_parameters键值下。

1
2
3
4
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, another_parameter, *args, **kwargs):
super().__init__()
self.save_hyperparameters()
  • self.save_hyperparameters()中提供ignore参数,该参数表示需要忽略哪些参数。对应参数无法通过checkpoint加载得到,需要手动指定

通过load_from_checkpint方法,可以从checkpoint中加载LightningModule模型。此时如果checkpoint中包含hyper_parameters,该部分value也会被加载进来,之后可以直接访问。当然也可以不使用checkpoint中的参数,在方法中指定的参数具有更高的优先级。

1
2
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
model2 = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt", learning_rate=1e-6)

Lightning checkpoint的格式与torch原始的nn.module是一致的,因此也可以用lightning checkpoint来加载nn.module.

前面我们提到lightning checkpoint中不仅仅记录了模型的信息,还记录了几乎所有运行状态信息。这允许我们直接通过lightning checkpoint来恢复运行状态,代码如下:

1
2
3
4
5
model = MyLightningModule()
trainer = Trainer()

# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

Hooks

在LightningModule中提供了许多Hooks钩子函数,通过实现这些函数,我们可以在fit()的流程中实现许多自定义的逻辑,在流程不同阶段注入用户自定义代码,这也是Pytorch Lightning灵活性的体现。详细的函数说明和流程说明可以查看官方文档LightningModule|Hooks — PyTorch Lightning documentation

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# runs on every device: devices can be GPUs, TPUs, ...
def fit(self):
configure_callbacks()

if local_rank == 0:
prepare_data()

setup("fit")
configure_model()
configure_optimizers()

on_fit_start()

# the sanity check runs here

on_train_start()
for epoch in epochs:
fit_loop()
on_train_end()

on_fit_end()
teardown("fit")


def fit_loop():
torch.set_grad_enabled(True)

on_train_epoch_start()

for batch in train_dataloader():
on_train_batch_start()

on_before_batch_transfer()
transfer_batch_to_device()
on_after_batch_transfer()

out = training_step()

on_before_zero_grad()
optimizer_zero_grad()

on_before_backward()
backward()
on_after_backward()

on_before_optimizer_step()
configure_gradient_clipping()
optimizer_step()

on_train_batch_end(out, batch, batch_idx)

if should_check_val:
val_loop()

on_train_epoch_end()


def val_loop():
on_validation_model_eval() # calls `model.eval()`
torch.set_grad_enabled(False)

on_validation_start()
on_validation_epoch_start()

for batch_idx, batch in enumerate(val_dataloader()):
on_validation_batch_start(batch, batch_idx)

batch = on_before_batch_transfer(batch)
batch = transfer_batch_to_device(batch)
batch = on_after_batch_transfer(batch)

out = validation_step(batch, batch_idx)

on_validation_batch_end(out, batch, batch_idx)

on_validation_epoch_end()
on_validation_end()

# set up for train
on_validation_model_train() # calls `model.train()`
torch.set_grad_enabled(True)

fit()流程是LightningModule的核心,它将模型训练的全流程进行了抽象,并且将每个抽象对应的实现都作为LightningModule的Hooks。PytorchLightning在调用fit的时候,背后就是按照上述流程进行执行的,其中每个抽象函数将会替换为具体的实现。

可以说这个抽象流程就是理解PytorchLightning的核心,而对应的抽象方法在用到的时候再查询相关文档即可。

Trainer

Trainer是Pytorch Lightning中另一大核心类。Trainer中提供了许多模版流程,包括训练、验证、测试、推理等,这些流程是对应操作的抽象组织,在真实调用的时候,会使用在LightningModule中实际定义的Hooks函数。Trainer的好处在于它帮助我们封装了许多繁琐操作,这些操作与模型无关,例如在不同阶段开启和禁用梯度、运行训练,验证和测试DataLoader、将数据放置在正确的设备上等等。另外,Trainer还允许用户定义不同的callback函数,而Trainer会自动在对应的时机执行对应的callback,类似于事件绑定。Trainer中常用的参数可以在官方文档Trainer — PyTorch Lightning documentation中找到。

DataModule

虽然Pytorch Lightning支持直接使用DataLoader,但是它还是提供了对数据的封装,DataModule。DataModule是一个可共享、可重用的类,其中封装了处理数据所需要的所有步骤。包括:

  1. 数据下载/tokenize/process处理
  2. 数据清洗/数据存储
  3. 数据集Dataset加载
  4. 数据转换transform
  5. 包装成Dataloader

在传统的Pytorch代码中,数据处理的代码可能分散在各个不同的地方,没有规范的管理,造成重用的困难。DataModule主要是为了解决这一问题,它将数据处理相关的代码进行管理和封装,使得代码共享和重用变得更加简单。

DataModule的引入主要还是为了对数据处理的逻辑进行抽象,这样在进行模型开发的时候,只需要对接DataModule,至于其他的逻辑,就交给DataModule具体实现。可以说DataModule只是对数据处理相关逻辑进行重新组织,使其向外暴露统一规范的接口。

DataModule在Pytorch Lightning中对应的类是LightningDataModule。下面是一个DataModule的代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import lightning as L
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(L.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)

def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders
if stage == "fit":
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)

# Assign test dataset for use in dataloader(s)
if stage == "test":
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

if stage == "predict":
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)

def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)

与LightningModule类似,DataModule需要我们实现如下特定的接口:

  • prepare_data: 数据准备,完成数据的下载、保存等操作。该函数由主进程main调用,并且仅在cpu上运行
  • setup: 在prepare_data执行完成之后执行。这里主要是存放其他可能在GPU上执行的操作,例如计算类别数目、构建词汇表、进行数据划分、创建datasets、执行数据转换等
  • tarin_dataloader: 生成(返回)训练数据的DataLoader
  • val_dataloader: 生成(返回)评估数据的DataLoader
  • test_dataloader: 生成(返回)测试数据的DataLoader
  • predict_dataloader: 生成(返回)预测数据的DataLoader

如果我们定义好了一个DataModule,那么就可以在Trainer中进行使用,如下所示:

1
2
3
4
5
mnist_datamodule = MNISTDataModule(my_path)
model = LitAutoEncoder()
trainer = Trainer()

trainer.fit(model, datamodule=mnist_datamodule)

其余相关API可以参考文档:LightningDataModule — PyTorch Lightning documentation

Other Features

CallBacks

Reference: Callback — PyTorch Lightning documentation

Trainer允许用户定义不同的callback函数,这些callback函数会在对应的时机被执行,可以类比AOP面向切面编程,也可以类比事件绑定。通常来说,Callback函数是与某个特定的训练过程无关的,独立于训练逻辑的“非必要”代码。

使用Callback,需要我们继承Callback类。例如在下面的代码中,我们使用一个Callback,在训练开始和结束的时候分别输出不同的提示信息:

1
2
3
4
5
6
7
8
9
10
11
12
from lightning.pytorch.callbacks import Callback


class MyPrintingCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("Training is starting")

def on_train_end(self, trainer, pl_module):
print("Training is ending")


trainer = Trainer(callbacks=[MyPrintingCallback()])

Callback函数可以类比LightningModule中的Hook函数,区别在于Hook函数通常是模型训练、测试和评估等抽象逻辑所必需的函数,而Callback函数并不是必须的。对于那些非必需的业务操作,我们就可以将它实现在Callback中,而不会污染LightningModule的代码,保证了各个模块的独立性。

Early Stopping

在模型训练的过程中,可能在中途达到了某些指标,此时我们希望能够让它停止训练。这是一个非常常见的需求,在传统的Pytorch代码中,我们需要自己控制条件判断来达到这个效果,而在PytorchLightning中,我们可以非常方便的做到。

第一种方式是使用LightningModule提供的其中一个Hook on_train_batch_start()。该Hook方法会在每个epoch中,每个batch执行训练之前调用。如果该方法返回-1,那么就会跳过当前这个epoch的训练。因此我们可以在这个Hook方法中做对应early stopping的判断,做到epoch的跳过。由于每个epoch都会有相同的逻辑,因此后续训练过程中的所有epoch也会重复跳过,最终停止所有的训练逻辑。

另一种方式使用EarlyStopping CallBack,它能够监控某个指标,当观察到这个指标没有提升,就会停止训练。

要使用该Callbak,我们首先需要在validation_step即验证过程中记录对应的验证指标,例如val_loss。之后在Trainer中指定使用EarlyStopping Callback,在其中指定monitormode参数,表示需要监控的指标以及以何种方式监控。

1
2
3
4
5
6
7
8
9
10
11
12
from lightning.pytorch.callbacks.early_stopping import EarlyStopping


class LitModel(LightningModule):
def validation_step(self, batch, batch_idx):
loss = ...
self.log("val_loss", loss)


model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)

Experiment Track

Reference: Track and Visualize Experiments — PyTorch Lightning documentation

在模型训练的过程中,指标的记录和跟踪是非常重要的。通过记录指标,我们可以可视化模型的学习过程,以便我们做出相应的调整。利用Pytorch Lightning,我们能够做到多种指标的监控,包括数值、图像、音频甚至视频等。

首先是最简单的数值指标的监控。这个只需要在LightningModule类内部调用self.log()即可,在调用的同时指定key和value。如果要一次性记录多个指标,则可以使用self.log_dict(),其中提供对应的metric字典。

1
2
3
4
5
6
7
8
class LitModel(L.LightningModule):
def training_step(self, batch, batch_idx):
# log单个metric
value = ...
self.log("some_value", value)
# log多个metric
values_dict = {"loss": loss, "acc": acc, "metric_n": metric_n}
self.log_dict(values_dict)

Pytorch Lightning还支持集成第三方的实验日志管理器,例如tensorboad,wandb等。例如,如果在运行环境中还安装了tensorboard,上面的指标默认也会呈现在tensorboard中;通过lightning.pytorch.loggers中提供的WandbLogger,我们就可以直接调用原始wandb中的API。

要在训练过程中使用这些日志记录器,只需要得到对应的logger对象,然后在Trainer初始化的时候传入即可,之后在其他地方,则可以通过self.logger来访问

1
2
3
4
5
6
7
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

logger1 = TensorBoardLogger()
logger2 = WandbLogger()

trainer1 = Trainer(logger=[logger1, logger2])
trainer2 = Trainer(logger=logger2)
1
2
3
4
5
6
7
8
9
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
tensorboard_logger = self.loggers.experiment[0]
wandb_logger = self.loggers.experiment[1]

fake_images = torch.Tensor(32, 3, 28, 28)

tensorboard_logger.add_image("generated_images", fake_images, 0)
wandb_logger.add_image("generated_images", fake_images, 0)

Project Template

虽然Pytorch Lightning的目的就是在于能够让深度学习代码复用和管理变得简单,但是如果没有很好的使用规范,同样会陷入管理困难,重用复杂的困境。尤其是Pytorch Lightning在Pytroch的基础上引入了更多的抽象和封装,这就像一柄双刃剑,如果不能很好地组织和管理,反而会带来更多的复杂度。

这里介绍一个结合了Hydra和Pytorch Lightning的项目框架:ashleve/lightning-hydra-template: PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡


Pytorch Lightning快速入门
http://example.com/2024/01/19/Pytorch-Lightning快速入门/
作者
EverNorif
发布于
2024年1月19日
许可协议