メインコンテンツまでスキップ

PyTorch Lightning

Open In Colab

PyTorch Lightningは、PyTorchコードを整理し、分散トレーニングや16ビット精度などの高度な機能を簡単に追加するための軽量なラッパーを提供します。W&Bは、ML実験のログを取るための軽量なラッパーを提供します。しかし、両方を自分で組み合わせる必要はありません。Weights & Biasesは、PyTorch LightningライブラリにWandbLoggerを介して直接組み込まれています。

⚡ たった二行で素早く始めましょう。

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)

どこからでもアクセス可能なインタラクティブなダッシュボードなど!

wandbにサインアップし、ログインする

a) 無料アカウントにサインアップする

b) wandbライブラリをPipインストール

c) トレーニングスクリプトでログインするには、www.wandb.aiでアカウントにサインインしてから、[**Authorizeページ**](https://wandb.ai/authorize)で**APIキーを見つけてください。**

もし、Weights and Biasesを初めて使う場合は、クイックスタートをチェックしてみてください。

pip install wandb

wandb login

PyTorch Lightning の WandbLogger を使う方法

PyTorch Lightning には、メトリクスやモデルの重み、メディアなどをシームレスにログに記録できる WandbLogger クラスがあります。WandbLogger をインスタンス化し、Lightning の Trainer に渡すだけです。

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)

Logger 引数

以下は、WandbLoggerでよく使われるパラメータの一部です。完全なリストと説明は、PyTorch LightningのWandbLoggerドキュメントをご覧ください。

パラメータ説明
projectwandb プロジェクトにログを送る
namewandb runに名前を付ける
log_modellog_model="all" ですべてのモデルをログするか、log_model=True でトレーニング終了時にログする
save_dirデータが保存されるパス

LightningModule のハイパーパラメーターをログに記録する

class LitModule(LightningModule):
def __init__(self, *args, **kwarg):
self.save_hyperparameters()

さらなるconfigパラメータをログに記録する

# 1つのパラメーターを追加する
wandb_logger.experiment.config["key"] = value
# 複数のパラメータを追加する
wandb_logger.experiment.config.update({key1: val1, key2: val2})

# wandbモジュールを直接使う
wandb.config["キー"] =
wandb.config.update()

勾配、パラメータヒストグラム、モデルトポロジーを記録する

モデルオブジェクトを wandblogger.watch() に渡すことで、トレーニング中のモデルの勾配やパラメータを監視できます。詳細は PyTorch Lightning の WandbLogger ドキュメント をご覧ください。

メトリクスを記録する

WandbLogger を使っている場合、LightningModule 内の self.log('my_metric_name', metric_value) を呼び出すことで、W&Bにメトリクスを記録できます。これはあなたの training_step や __validation_step メソッド内で行うことができます。

以下のコードスニペットは、メトリクスと LightningModule のハイパーパラメータを記録するように LightningModule を定義する方法を示しています。この例では、torchmetrics ライブラリを使ってメトリクスを計算します。

import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

class My_LitModule(LightningModule):

def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
'''モデルのパラメータを定義するメソッド'''
super().__init__()

# mnist画像は(1, 28, 28) (チャンネル, 幅, 高さ)です。
self.layer_1 = Linear(28 * 28, n_layer_1)
self.layer_2 = Linear(n_layer_1, n_layer_2)
self.layer_3 = Linear(n_layer_2, n_classes)

self.loss = CrossEntropyLoss()
self.lr = lr

# ハイパーパラメーターをself.hparamsに保存する (W&Bによって自動ロギングされます)
self.save_hyperparameters()

def forward(self, x):
'''推論用の入力から出力へのメソッド'''

# (b, 1, 28, 28) -> (b, 1*28*28)
batch_size, channels, width, height = x.size()
x = x.view(batch_size, -1)

# 3回(linear + relu)をやりましょう
x = F.relu(self.layer_1(x))
x = F.relu(self.layer_2(x))
x = self.layer_3(x)
return x

def training_step(self, batch, batch_idx):
'''単一バッチからの損失を返す必要があります'''
_, loss, acc = self._get_preds_loss_accuracy(batch)

# 損失とメトリックを記録
self.log('train_loss', loss)
self.log('train_accuracy', acc)
return loss
def validation_step(self, batch, batch_idx):
'''メトリクスのログ記録に使用されます'''
preds, loss, acc = self._get_preds_loss_accuracy(batch)

# 損失とメトリクスをログに記録
self.log('val_loss', loss)
self.log('val_accuracy', acc)
return preds

def configure_optimizers(self):
'''モデルのオプティマイザを定義します'''
return Adam(self.parameters(), lr=self.lr)

def _get_preds_loss_accuracy(self, batch):
'''トレーニング/検証/テストのステップが似ているため、便利な関数です'''
x, y = batch
logits = self(x)
preds = torch.argmax(logits, dim=1)
loss = self.loss(logits, y)
acc = accuracy(preds, y)
return preds, loss, acc

メトリックの最小値/最大値をログに記録

wandbのdefine_metric関数を使用することで、W&Bのサマリーメトリックに表示されるメトリックの最小値、最大値、平均値、または最適値を定義できます。define_metricが使われていない場合、ログに記録された最後の値がサマリーメトリックに表示されます。define_metricのリファレンスドキュメントとガイドを参照してください。

W&Bのサマリーメトリックで最大検証精度を追跡するように指示するには、wandb.define_metricを1回呼び出すだけです。例えば、トレーニングの開始時に以下のように呼び出すことができます。

class My_LitModule(LightningModule):
...

def validation_step(self, batch, batch_idx):
if trainer.global_step == 0:
wandb.define_metric('val_accuracy', summary='max')

preds, loss, acc = self._get_preds_loss_accuracy(batch)

# ロスと指標をログする
self.log('val_loss', loss)
self.log('val_accuracy', acc)
return preds

モデルのチェックポイント作成

W&Bにカスタムチェックポイントを設定するには、PyTorch Lightningの ModelCheckpointWandbLogger のlog_model 引数で使用します。

# `val_accuracy`が増加する場合にのみモデルをログする
wandb_logger = WandbLogger(log_model="all")
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])

最新と最良のエイリアスが自動的に設定されているため、W&Bのアーティファクトから簡単にモデルチェックポイントを取得できます。

# アーティファクトパネルで参照を取得できます
# "VERSION"はバージョン(例:"v2")またはエイリアス("latest"または"best")であることができます
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"


# チェックポイントをローカルにダウンロード(すでにキャッシュされていない場合)
run = wandb.init(project="MNIST")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()

# チェックポイントの読み込み
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")

画像、テキスト、その他のログ

WandbLoggerには、メディアをログするためのlog_imagelog_textlog_tableメソッドがあります。

また、Audio、Molecules、Point Clouds、3Dオブジェクトなどの他のメディアタイプをログするために、直接wandb.logtrainer.logger.experiment.logを呼び出すこともできます。

# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])

# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# using .log in the trainer
trainer.logger.experiment.log({
"samples": [wandb.Image(img, caption=caption)
for (img, caption) in my_images]
})

LightningのCallbacksシステムを使用して、WandbLoggerを介してWeights & Biasesにログを記録するタイミングを制御できます。この例では、検証画像のサンプルと予測をログに記録しています。

import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

class LogPredictionSamplesCallback(Callback):

def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
"""検証バッチが終了したときに呼び出されます。"""

# `outputs`は`LightningModule.validation_step`から来ています
# これは、この場合、モデルの予測に対応します

# 1つ目のバッチから20個のサンプル画像の予測をログに記録しましょう
if batch_idx == 0:
n = 20
x, y = batch
images = [img for img in x[:n]]
captions = [f'正解: {y_i} - 予測: {y_pred}'
for y_i, y_pred in zip(y[:n], outputs[:n])]


# オプション1:`WandbLogger.log_image`を使って画像をログに記録する
wandb_logger.log_image(
key='sample_images',
images=images,
caption=captions)


# オプション2:画像と予測をW&Bテーブルとしてログに記録する
columns = ['image', 'ground truth', 'prediction']
data = [[wandb.Image(x_i), y_i, y_pred] f
or x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))]
wandb_logger.log_table(
key='sample_table',
columns=columns,
data=data)
...

trainer = pl.Trainer(
...
callbacks=[LogPredictionSamplesCallback()]
)

LightningとW&Bを使って複数のGPUを使用する方法は?

PyTorch Lightningは、DDPインターフェースを介して、複数のGPUをサポートしています。ただし、PyTorch Lightningの設計では、GPUのインスタンス化方法に注意が必要です。

Lightningは、トレーニングループ内の各GPU(またはランク)が、同じ初期条件で正確に同じ方法でインスタンス化されていることを前提としています。ただし、ランク0プロセスのみがwandb.runオブジェクトにアクセスでき、ランクが0でないプロセスの場合はwandb.run = Noneです。このことで、ランクが0でないプロセスが失敗する可能性があります。このような状況は、ランク0プロセスが既にクラッシュしたランク0以外のプロセスに参加していないため、「デッドロック」に陥る可能性があります。

このため、トレーニングコードを設定する方法に注意が必要です。おすすめの方法は、コードがwandb.runオブジェクトから独立しているように設定することです。

class MNISTClassifier(pl.LightningModule):
def __init__(self):
super(MNISTClassifier, self).__init__()

self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 128),
nn.ReLU(),
nn.Linear(128, 10),
)

self.loss = nn.CrossEntropyLoss()

def forward(self, x):
return self.model(x)

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.forward(x)
loss = self.loss(y_hat, y)
model = MNISTClassifier()
wandb_logger = WandbLogger(project = "<project_name>")
callbacks = [
ModelCheckpoint(
dirpath = "checkpoints",
every_n_train_steps=100,
),
]
trainer = pl.Trainer(
max_epochs = 3,
gpus = 2,
logger = wandb_logger,
strategy="ddp",
callbacks=callbacks
)
trainer.fit(model, train_loader, val_loader)

インタラクティブな例をチェックしよう!

私たちのチュートリアルColab こちらでビデオチュートリアルに沿って進めることができます。

よくある質問

W&BはLightningとどのように統合されていますか?

コアの統合は、Lightning loggers APIをベースにしており、フレームワークに依存しない方法でログコードの多くを記述できます。LoggerLightning Trainerに渡され、APIの豊富なフック・コールバックシステムに基づいてトリガーされます。これにより、研究コードとエンジニアリング、ログコードがうまく分離されます。

追加コードなしで何をログに残しますか?

モデルのチェックポイントをW&Bに保存し、閲覧や今後のrunでの使用のためにダウンロードすることができます。また、GPU使用量やネットワークI/Oなどのシステムメトリクス、ハードウェアやOS情報などの環境情報、コードの状態(gitコミットや差分パッチ、ノートブックの内容やセッション履歴を含む)、および標準出力に出力される内容をすべてキャプチャします。

トレーニング設定でwandb.runを使用する必要がある場合はどうすればいいですか?

自分でアクセスする必要がある変数のスコープを基本的に拡張する必要があります。言い換えれば、すべてのプロセスで初期条件が同じであることを確認することです。


if os.environ.get("LOCAL_RANK", None) is None:

os.environ["WANDB_DIR"] = wandb.run.dir

そして、os.environ["WANDB_DIR"]を使用してモデルチェックポイントのディレクトリを設定できます。この方法で、wandb.run.dirはゼロ以外のランクのプロセスにも使用できます。

Was this page helpful?👍👎