import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule

class My_LitModule(LightningModule):

    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        '''method used to define our model parameters'''
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        # loss
        self.loss = CrossEntropyLoss()

        # save hyper-parameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters()

    def forward(self, x):
        '''method used for inference input -> output'''
        # (b, 1, 28, 28) -> (b, 1*28*28)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        # let's do 3 x (linear + relu)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # log loss and metric
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return preds

    def configure_optimizers(self):
        '''defines model optimizer'''
        # lr was stored in self.hparams by save_hyperparameters()
        return Adam(self.parameters(), lr=self.hparams.lr)

    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y, task='multiclass', num_classes=self.hparams.n_classes)
        return preds, loss, acc
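
# --- Usage sketch (not from the original): one way to train the module above. ---
# Assumptions: MNIST via torchvision, a 55k/5k train/val split, batch size 64,
# 5 epochs, and W&B logging through WandbLogger with a placeholder project name.
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# download MNIST and split off a validation set
dataset = MNIST('data', download=True, transform=transforms.ToTensor())
train_set, val_set = random_split(dataset, [55000, 5000])
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)

wandb_logger = WandbLogger(project='lit-mnist')  # hypothetical project name

model = My_LitModule()
trainer = Trainer(logger=wandb_logger, max_epochs=5)
trainer.fit(model, train_loader, val_loader)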