本文最后更新于:2024-07-06T21:41:01+08:00
PyTorch
Lightning是一个轻量级的PyTorch深度学习框架,旨在简化和规范深度学习模型的训练过程。它提供了一组模块和接口,使用户能够更容易地组织和训练模型,同时减少样板代码的数量。PyTorch
Lightning的设计目标是提高代码的可读性和可维护性,同时保持灵活性。它通过将训练循环的组件拆分为独立的模块(如模型、优化器、调度器等),以及提供默认实现来简化用户代码。这使得用户可以专注于模型的定义和高级训练策略,而不必处理底层训练循环的细节。
Pytorch Lightning的官方项目地址为:Lightning-AI/pytorch-lightning 。可以通过pip或者conda进行安装:
1 2 3 pip install lightning conda install lightning -c conda-forge
Quick Start
在一个基础的深度学习流程中,我们会经历数据准备、模型搭建以及模型训练过程,在Pytorch中,它们分别对应DataLoader、Model以及Optimizer。Pytorch
lightning则将这些组件进一步拆分成独立的模块,在保持灵活性的同时还提供可读性和可维护性。接下来我们以一个简单的案例来介绍Pytorch
lightning的使用。
数据准备
同样首先我们需要进行数据准备,数据准备方面与Pytorch基本没有区别,只需要准备好对应的Dataset和DataLoader即可。
1 2 3 4 5 6 7 from torchvision.datasets import MNISTfrom torchvision.transforms import ToTensorfrom torch import utils dataset = MNIST('./data' , download=True , transform=ToTensor()) train_loader = utils.data.DataLoader(dataset)
模型定义
在Pytorch中,模型Model的定义需要继承nn.Module
。而在Pytorch
lightning中,模型均为lightning.LightningModule
的子类。在LightningModule中,提供了许多接口,分别对应模型训练步骤、测试步骤、评估步骤、优化器配置等等。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 import torchfrom torch import optim, nnimport lightning as Limport torch.nn.functional as F encoder = nn.Sequential(nn.Linear(28 * 28 , 64 ), nn.ReLU(), nn.Linear(64 , 3 )) decoder = nn.Sequential(nn.Linear(3 , 64 ), nn.ReLU(), nn.Linear(64 , 28 * 28 ))class LitAutoEncoder (L.LightningModule): def __init__ (self, encoder, decoder ): super ().__init__() self.encoder = encoder self.decoder = decoder def training_step (self, batch, batch_idx ): x, y = batch x = x.view(x.size(0 ), -1 ) z = self.encoder(x) x_hat = self.decoder(z) loss = nn.functional.mse_loss(x_hat, x) self.log("train_loss" , loss) return loss def configure_optimizers (self ): optimizer = optim.Adam(self.parameters(), lr=1e-3 ) return optimizer autoencoder = LitAutoEncoder(encoder, decoder)
在这里,我们定义了一个简单的网络,包含一个encoder和一个decoder,分别由多层MLP组成,一起组成了一个autoencoder。在LightningModule中,提供许多接口方法,我们需要实现这些方法。例如在这里,我们首先实现了training_step
和configure_optimizers
方法,分别完成了训练步骤的定义和优化器配置。
其中,trainint_step
方法参数中的batch和batch_idx,实际上就对应DataLoader每次迭代会提供的batch和idx。
模型训练
完成了模型定义和数据准备之后,就可以进行模型训练了。在Pytorch
Lightning中,使用Trainer来完成模型训练的过程。
1 2 3 trainer = L.Trainer(limit_train_batches=100 , max_epochs=1 ) trainer.fit(model=autoencoder, train_dataloaders=train_loader)
上面我们首先实例化了一个Trainer,其中指定了相关的参数,然后调用fit
方法,开始模型的训练。在模型训练之前,Pytorch
Lightning会自动检测系统可用的加速设备并进行使用,无需我们显式地调用.to(device)
等。如果想要支持多设备训练,也只需要在实例化Trainer的时候指定相关参数即可。
在实际训练过程中,可再现性是非常重要的。深度学习过程中具有非常多的随机性,为了确保每次运行的可重复性,我们需要手动设置随机数种子,并在Trainer中设置determinstic
标志。
1 2 3 4 5 6 from lightning.pytorch import Trainer, seed_everything seed_everything(42 , workers=True ) model = Model() trainer = Trainer(deterministic=True )
默认情况下,在训练过程中,会保存一个存储了最近训练状态的checkpoint,我们可以将其加载并且得到一个模型Model。
1 2 3 4 5 6 7 8 9 10 11 12 checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt" autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder) encoder = autoencoder.encoder encoder.eval () fake_image_batch = torch.rand(4 , 28 * 28 , device=autoencoder.device) embeddings = encoder(fake_image_batch)print ("⚡" * 20 , "\nPredictions (4 image embeddings):\n" , embeddings, "\n" , "⚡" * 20 )
模型评估和测试
通常,在深度学习的流程当中,还会有模型评估和测试的参与。这里我们也可以补充这一部分。当然首先需要完成对应数据集的准备。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import torch.utils.data as datafrom torchvision import datasetsimport torchvision.transforms as transforms transform = transforms.ToTensor() train_set = datasets.MNIST(root='./data' , download=True , train=True , transform=transform) test_set = datasets.MNIST(root='./data' , download=True , train=False , transform=transform) train_set_size = int (len (train_set) * 0.8 ) valid_set_size = len (train_set) - train_set_size seed = torch.Generator().manual_seed(42 ) train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
在LightningModule中也提供了可选的validation_step
和test_step
方法,我们可以实现这些接口方法来实现模型评估和测试。其中方法参数与training_step类似,都是直接从DataLoader中得到的batch和idx。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 class LitAutoEncoder (L.LightningModule): def __init__ (self, encoder, decoder ): super ().__init__() self.encoder = encoder self.decoder = decoder def training_step (self, batch, batch_idx ): ... def test_step (self, batch, batch_idx ): x, y = batch x = x.view(x.size(0 ), -1 ) z = self.encoder(x) x_hat = self.decoder(z) test_loss = nn.functional.mse_loss(x_hat, x) self.log("test_loss" , test_loss) return test_loss def validation_step (self, batch, batch_idx ): x, y = batch x = x.view(x.size(0 ), -1 ) z = self.encoder(x) x_hat = self.decoder(z) val_loss = nn.functional.mse_loss(x_hat, x) self.log("val_loss" , val_loss) return val_loss def configure_optimizers (self ): ...
在默认情况下,模型评估是在每一个epoch训练完成之后进行的。
1 2 3 4 5 6 7 8 9 10 from torch.utils.data import DataLoader train_loader = DataLoader(train_set) valid_loader = DataLoader(valid_set) model = LitAutoEncoder(encoder, decoder) trainer = L.Trainer(limit_train_batches=1000 , limit_val_batches=800 , max_epochs=2 ) trainer.fit(model, train_loader, valid_loader)
而模型测试对应是在模型训练完成之后进行。它对应的方法是test
,在其中提供测试数据对应的DataLoader。
1 2 3 4 5 from torch.utils.data import DataLoader trainer = L.Trainer() model = LitAutoEncoder(encoder, decoder) trainer.test(model, dataloaders=DataLoader(test_set))
LightningModule
LightningModules是Pytorch
Lightning的一大核心类。LightningModule主要在6个方面来组织Pytorch代码,它们分别对应不同接口方法。
init初始化:对应__init__
和setup
方法
train loop:对应training_step
方法
validation loop:对应validation_step
方法
test loop:对应test_step
方法
prediction loop:对应predict_step
方法
optimizers and LR
schedulers:对应configure_optimizers
方法
Training
在LightningModule中我们可以重写forward方法。使用这个方法可以让我们以直接调用的模式调用模型,类似nn.Module
,如下所示:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import torchfrom torch import optim, nnimport lightning as Limport torch.nn.functional as F encoder = nn.Sequential(nn.Linear(28 * 28 , 64 ), nn.ReLU(), nn.Linear(64 , 3 )) decoder = nn.Sequential(nn.Linear(3 , 64 ), nn.ReLU(), nn.Linear(64 , 28 * 28 ))class LitAutoEncoder (L.LightningModule): def __init__ (self, encoder, decoder ): super ().__init__() self.encoder = encoder self.decoder = decoder def forward (self, inputs ): return self.decoder(self.encoder(inputs))
借用forward方法,我们也可以简化training_step的书写。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 class LitAutoEncoder (L.LightningModule): def __init__ (self, encoder, decoder ): super ().__init__() self.encoder = encoder self.decoder = decoder def forward (self, inputs ): return self.decoder(self.encoder(inputs)) def training_step (self, batch, batch_idx ): inputs, targets = batch inputs = inputs.view(inputs.size(0 ), -1 ) outputs = self(inputs) loss = nn.functional.mse_loss(outputs, inputs) self.log("train_loss" , loss) return loss
训练循环对应training_step
方法,该方法需要返回一个loss。重写该方法之后,对应执行的逻辑如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 model.train() torch.set_grad_enabled(True )for batch_idx, batch in enumerate (train_dataloader): loss = training_step(batch, batch_idx) optimizer.zero_grad() loss.backward() optimizer.step()
如果希望记录epoch级别的metric,则可以调用self.log()
方法。该方法会接受两个参数,分别是key和value,其中value会被组织成list的形式,最终返回的是list上的平均作为最终的epoch
metric。例如在上面的training_step中我们调用了self.log
方法,分别传入train_loss和loss,最后就会得到一个epoch级别的train_loss,它是通过在epoch中记录的所有loss计算平均得到的。类似于下面的逻辑:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 outs = []for batch_idx, batch in enumerate (train_dataloader): loss = training_step(batch, batch_idx) outs.append(loss.detach()) optimizer.zero_grad() loss.backward() optimizer.step() epoch_metric = torch.mean(torch.stack(outs))
Validation
模型评估对应的是validation_step
方法:
1 2 3 4 5 6 7 def validation_step (self, batch, batch_idx ): inputs, targets = batch inputs = inputs.view(inputs.size(0 ), -1 ) outputs = self(inputs) loss = nn.functional.mse_loss(outputs, inputs) self.log("val_loss" , loss) return loss
实际对应伪代码逻辑如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 for batch_idx, batch in enumerate (train_dataloader): loss = model.training_step(batch, batch_idx) loss.backward() if validate_at_some_point: torch.set_grad_enabled(False ) model.eval () for val_batch_idx, val_batch in enumerate (val_dataloader): val_out = model.validation_step(val_batch, val_batch_idx) torch.set_grad_enabled(True ) model.train()
在Trainer中也提供validate接口调用在val datalodar上的验证逻辑:
1 2 trainer = L.Trainer() trainer.validate(model)
Testing
模型测试对应的是test_step
方法:
1 2 3 4 5 6 7 def test_step (self, batch, batch_idx ): inputs, targets = batch inputs = inputs.view(inputs.size(0 ), -1 ) outputs = self(inputs) loss = nn.functional.mse_loss(outputs, inputs) self.log("test_loss" , loss) return loss
测试循环仅仅在Trainer的test方法中被调用
1 2 3 trainer = L.Trainer() model = LitAutoEncoder(encoder, decoder) trainer.test(model, dataloaders=test_dataloader)
Inference
模型推理即表示模型对给定输入如何进行处理并得到输出,对应方法predict_step
:
1 2 3 4 def predict_step (self, batch ): inputs, targets = batch inputs = inputs.view(inputs.size(0 ), -1 ) return self(inputs)
对应的伪代码逻辑如下:
1 2 3 4 5 6 7 8 torch.set_grad_enabled(False ) model.eval () all_preds = []for batch_idx, batch in enumerate (predict_dataloader): pred = model.predict_step(batch, batch_idx) all_preds.append(pred)
可以通过Trainer的predict方法直接调用:
1 2 trainer = L.Trainer() predictions = trainer.predict(autoencoder, dataloaders=train_loader)
默认情况下,
predict_step()执行的是forward()方法。如果我们想用forward方法对应的模型直接调用并进行推理,在代码中需要显式指定禁止梯度回传。
1 2 3 4 model.eval ()with torch.no_grad(): batch = dataloader.dataset[0 ] pred = model(batch)
Save and Load Model
Pytorch lightning中提供lightning checkpoint来保存对象状态。lightning
checkpoint不是仅仅保存了模型当前的参数,而是基本保存了模型完整的内部状态,包括当前epoch、global
step、模型的state_dict、optimzer状态、scheduler状态等等。利用lightning
checkpoint,我们可以完全无缝地恢复全部运行状态。
Trainer()
会默认在当前目录下保存一个checkpoint,记录上一个训练周期的状态。这点可以帮助我们在训练中断的时候及时恢复训练。checkpoint保存的位置可以通过可以通过default_root_dir
参数来修改。如果想要禁用该功能,则可以通过enable_checkpointing=False
来控制。
1 2 trainer = Trainer(default_root_dir='your/path/here' ) trainer_no_checkpoint = Trainer(enable_checkpointing=False )
在LightningModule中提供self.save_hyperparameters()
方法来保存所有在init构造方法中传入的超参数。在保存的checkpoint文件中,该部分参数被存储在hyper_parameters
键值下。
1 2 3 4 class MyLightningModule (LightningModule ): def __init__ (self, learning_rate, another_parameter, *args, **kwargs ): super ().__init__() self.save_hyperparameters()
self.save_hyperparameters()
中提供ignore参数,该参数表示需要忽略哪些参数。对应参数无法通过checkpoint加载得到,需要手动指定
通过load_from_checkpint
方法,可以从checkpoint中加载LightningModule模型。此时如果checkpoint中包含hyper_parameters,该部分value也会被加载进来,之后可以直接访问。当然也可以不使用checkpoint中的参数,在方法中指定的参数具有更高的优先级。
1 2 model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt" ) model2 = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt" , learning_rate=1e-6 )
Lightning
checkpoint的格式与torch原始的nn.module是一致的,因此也可以用lightning
checkpoint来加载nn.module.
前面我们提到lightning
checkpoint中不仅仅记录了模型的信息,还记录了几乎所有运行状态信息。这允许我们直接通过lightning
checkpoint来恢复运行状态,代码如下:
1 2 3 4 5 model = MyLightningModule() trainer = Trainer() trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt" )
Hooks
在LightningModule中提供了许多Hooks钩子函数,通过实现这些函数,我们可以在fit()
的流程中实现许多自定义的逻辑,在流程不同阶段注入用户自定义代码,这也是Pytorch
Lightning灵活性的体现。详细的函数说明和流程说明可以查看官方文档LightningModule|Hooks
— PyTorch Lightning documentation
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 def fit (self ): configure_callbacks() if local_rank == 0 : prepare_data() setup("fit" ) configure_model() configure_optimizers() on_fit_start() on_train_start() for epoch in epochs: fit_loop() on_train_end() on_fit_end() teardown("fit" )def fit_loop (): torch.set_grad_enabled(True ) on_train_epoch_start() for batch in train_dataloader(): on_train_batch_start() on_before_batch_transfer() transfer_batch_to_device() on_after_batch_transfer() out = training_step() on_before_zero_grad() optimizer_zero_grad() on_before_backward() backward() on_after_backward() on_before_optimizer_step() configure_gradient_clipping() optimizer_step() on_train_batch_end(out, batch, batch_idx) if should_check_val: val_loop() on_train_epoch_end()def val_loop (): on_validation_model_eval() torch.set_grad_enabled(False ) on_validation_start() on_validation_epoch_start() for batch_idx, batch in enumerate (val_dataloader()): on_validation_batch_start(batch, batch_idx) batch = on_before_batch_transfer(batch) batch = transfer_batch_to_device(batch) batch = on_after_batch_transfer(batch) out = validation_step(batch, batch_idx) on_validation_batch_end(out, batch, batch_idx) on_validation_epoch_end() on_validation_end() on_validation_model_train() torch.set_grad_enabled(True )
fit()流程是LightningModule的核心,它将模型训练的全流程进行了抽象,并且将每个抽象对应的实现都作为LightningModule的Hooks。PytorchLightning在调用fit的时候,背后就是按照上述流程进行执行的,其中每个抽象函数将会替换为具体的实现。
可以说这个抽象流程就是理解PytorchLightning的核心,而对应的抽象方法在用到的时候再查询相关文档即可。
Trainer
Trainer是Pytorch
Lightning中另一大核心类。Trainer中提供了许多模版流程,包括训练、验证、测试、推理等,这些流程是对应操作的抽象组织,在真实调用的时候,会使用在LightningModule中实际定义的Hooks函数。Trainer的好处在于它帮助我们封装了许多繁琐操作,这些操作与模型无关,例如在不同阶段开启和禁用梯度、运行训练,验证和测试DataLoader、将数据放置在正确的设备上等等。另外,Trainer还允许用户定义不同的callback函数,而Trainer会自动在对应的时机执行对应的callback,类似于事件绑定。Trainer中常用的参数可以在官方文档Trainer
— PyTorch Lightning documentation 中找到。
DataModule
虽然Pytorch
Lightning支持直接使用DataLoader,但是它还是提供了对数据的封装,DataModule。DataModule是一个可共享、可重用的类,其中封装了处理数据所需要的所有步骤。包括:
数据下载/tokenize/process处理
数据清洗/数据存储
数据集Dataset加载
数据转换transform
包装成Dataloader
在传统的Pytorch代码中,数据处理的代码可能分散在各个不同的地方,没有规范的管理,造成重用的困难。DataModule主要是为了解决这一问题,它将数据处理相关的代码进行管理和封装,使得代码共享和重用变得更加简单。
DataModule的引入主要还是为了对数据处理的逻辑进行抽象,这样在进行模型开发的时候,只需要对接DataModule,至于其他的逻辑,就交给DataModule具体实现。可以说DataModule只是对数据处理相关逻辑进行重新组织,使其向外暴露统一规范的接口。
DataModule在Pytorch
Lightning中对应的类是LightningDataModule
。下面是一个DataModule的代码示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 import lightning as Lfrom torch.utils.data import random_split, DataLoaderfrom torchvision.datasets import MNISTfrom torchvision import transformsclass MNISTDataModule (L.LightningDataModule): def __init__ (self, data_dir: str = "./" ): super ().__init__() self.data_dir = data_dir self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307 ,), (0.3081 ,))]) def prepare_data (self ): MNIST(self.data_dir, train=True , download=True ) MNIST(self.data_dir, train=False , download=True ) def setup (self, stage: str ): if stage == "fit" : mnist_full = MNIST(self.data_dir, train=True , transform=self.transform) self.mnist_train, self.mnist_val = random_split( mnist_full, [55000 , 5000 ], generator=torch.Generator().manual_seed(42 ) ) if stage == "test" : self.mnist_test = MNIST(self.data_dir, train=False , transform=self.transform) if stage == "predict" : self.mnist_predict = MNIST(self.data_dir, train=False , transform=self.transform) def train_dataloader (self ): return DataLoader(self.mnist_train, batch_size=32 ) def val_dataloader (self ): return DataLoader(self.mnist_val, batch_size=32 ) def test_dataloader (self ): return DataLoader(self.mnist_test, batch_size=32 ) def predict_dataloader (self ): return DataLoader(self.mnist_predict, batch_size=32 )
与LightningModule类似,DataModule需要我们实现如下特定的接口:
prepare_data
:
数据准备,完成数据的下载、保存等操作。该函数由主进程main调用,并且仅在cpu上运行
setup
:
在prepare_data执行完成之后执行。这里主要是存放其他可能在GPU上执行的操作,例如计算类别数目、构建词汇表、进行数据划分、创建datasets、执行数据转换等
tarin_dataloader
: 生成(返回)训练数据的DataLoader
val_dataloader
: 生成(返回)评估数据的DataLoader
test_dataloader
: 生成(返回)测试数据的DataLoader
predict_dataloader
:
生成(返回)预测数据的DataLoader
如果我们定义好了一个DataModule,那么就可以在Trainer中进行使用,如下所示:
1 2 3 4 5 mnist_datamodule = MNISTDataModule(my_path) model = LitAutoEncoder() trainer = Trainer() trainer.fit(model, datamodule=mnist_datamodule)
其余相关API可以参考文档:LightningDataModule
— PyTorch Lightning documentation
Other Features
CallBacks
Reference: Callback
— PyTorch Lightning documentation
Trainer允许用户定义不同的callback函数,这些callback函数会在对应的时机被执行,可以类比AOP面向切面编程,也可以类比事件绑定。通常来说,Callback函数是与某个特定的训练过程无关的,独立于训练逻辑的“非必要”代码。
使用Callback,需要我们继承Callback类。例如在下面的代码中,我们使用一个Callback,在训练开始和结束的时候分别输出不同的提示信息:
1 2 3 4 5 6 7 8 9 10 11 12 from lightning.pytorch.callbacks import Callbackclass MyPrintingCallback (Callback ): def on_train_start (self, trainer, pl_module ): print ("Training is starting" ) def on_train_end (self, trainer, pl_module ): print ("Training is ending" ) trainer = Trainer(callbacks=[MyPrintingCallback()])
Callback函数可以类比LightningModule中的Hook函数,区别在于Hook函数通常是模型训练、测试和评估等抽象逻辑所必需的函数,而Callback函数并不是必须的。对于那些非必需的业务操作,我们就可以将它实现在Callback中,而不会污染LightningModule的代码,保证了各个模块的独立性。
Early Stopping
在模型训练的过程中,可能在中途达到了某些指标,此时我们希望能够让它停止训练。这是一个非常常见的需求,在传统的Pytorch代码中,我们需要自己控制条件判断来达到这个效果,而在PytorchLightning中,我们可以非常方便的做到。
第一种方式是使用LightningModule提供的其中一个Hook
on_train_batch_start()
。该Hook方法会在每个epoch中,每个batch执行训练之前调用。如果该方法返回-1
,那么就会跳过当前这个epoch的训练。因此我们可以在这个Hook方法中做对应early
stopping的判断,做到epoch的跳过。由于每个epoch都会有相同的逻辑,因此后续训练过程中的所有epoch也会重复跳过,最终停止所有的训练逻辑。
另一种方式使用EarlyStopping
CallBack,它能够监控某个指标,当观察到这个指标没有提升,就会停止训练。
要使用该Callbak,我们首先需要在validation_step
即验证过程中记录对应的验证指标,例如val_loss
。之后在Trainer中指定使用EarlyStopping
Callback,在其中指定monitor
和mode
参数,表示需要监控的指标以及以何种方式监控。
1 2 3 4 5 6 7 8 9 10 11 12 from lightning.pytorch.callbacks.early_stopping import EarlyStoppingclass LitModel (LightningModule ): def validation_step (self, batch, batch_idx ): loss = ... self.log("val_loss" , loss) model = LitModel() trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss" , mode="min" )]) trainer.fit(model)
Experiment Track
Reference: Track
and Visualize Experiments — PyTorch Lightning documentation
在模型训练的过程中,指标的记录和跟踪是非常重要的。通过记录指标,我们可以可视化模型的学习过程,以便我们做出相应的调整。利用Pytorch
Lightning,我们能够做到多种指标的监控,包括数值、图像、音频甚至视频等。
首先是最简单的数值指标的监控。这个只需要在LightningModule类内部调用self.log()
即可,在调用的同时指定key和value。如果要一次性记录多个指标,则可以使用self.log_dict()
,其中提供对应的metric字典。
1 2 3 4 5 6 7 8 class LitModel (L.LightningModule): def training_step (self, batch, batch_idx ): value = ... self.log("some_value" , value) values_dict = {"loss" : loss, "acc" : acc, "metric_n" : metric_n} self.log_dict(values_dict)
Pytorch
Lightning还支持集成第三方的实验日志管理器,例如tensorboad,wandb等。例如,如果在运行环境中还安装了tensorboard,上面的指标默认也会呈现在tensorboard中;通过lightning.pytorch.loggers
中提供的WandbLogger ,我们就可以直接调用原始wandb中的API。
要在训练过程中使用这些日志记录器,只需要得到对应的logger对象,然后在Trainer初始化的时候传入即可,之后在其他地方,则可以通过self.logger
来访问
1 2 3 4 5 6 7 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger logger1 = TensorBoardLogger() logger2 = WandbLogger() trainer1 = Trainer(logger=[logger1, logger2]) trainer2 = Trainer(logger=logger2)
1 2 3 4 5 6 7 8 9 class MyModule (LightningModule ): def any_lightning_module_function_or_hook (self ): tensorboard_logger = self.loggers.experiment[0 ] wandb_logger = self.loggers.experiment[1 ] fake_images = torch.Tensor(32 , 3 , 28 , 28 ) tensorboard_logger.add_image("generated_images" , fake_images, 0 ) wandb_logger.add_image("generated_images" , fake_images, 0 )
Project Template
虽然Pytorch
Lightning的目的就是在于能够让深度学习代码复用和管理变得简单,但是如果没有很好的使用规范,同样会陷入管理困难,重用复杂的困境。尤其是Pytorch
Lightning在Pytroch的基础上引入了更多的抽象和封装,这就像一柄双刃剑,如果不能很好地组织和管理,反而会带来更多的复杂度。
这里介绍一个结合了Hydra和Pytorch Lightning的项目框架:ashleve/lightning-hydra-template:
PyTorch Lightning + Hydra. A very user-friendly template for ML
experimentation. ⚡🔥⚡