ModelCheckpoint

class paddle.callbacks.ModelCheckpoint(save_freq=1, save_dir=None)

Callback that saves model weights and optimizer state during training when used with model.fit(). Currently, ModelCheckpoint only supports saving once every fixed number of epochs.

Parameters
  • save_freq (int) – How often, in epochs, a checkpoint is saved. Default: 1.

  • save_dir (str|None) – The directory in which checkpoints are saved during training. If None, no checkpoints are saved. Default: None.
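Which epochs produce a checkpoint follows directly from save_freq. The sketch below assumes 0-based epoch indices and a save whenever `epoch % save_freq == 0`; the helper `checkpoint_epochs` is hypothetical and not part of Paddle, so confirm the exact condition against your Paddle version.

```python
def checkpoint_epochs(num_epochs, save_freq=1):
    """Return the 0-based epoch indices at which a per-epoch checkpoint would be written.

    Hypothetical helper: assumes the save condition is `epoch % save_freq == 0`.
    """
    if save_freq < 1:
        raise ValueError("save_freq must be a positive integer")
    return [epoch for epoch in range(num_epochs) if epoch % save_freq == 0]


print(checkpoint_epochs(5))      # save_freq=1: every epoch -> [0, 1, 2, 3, 4]
print(checkpoint_epochs(10, 3))  # every third epoch -> [0, 3, 6, 9]
```

With the default save_freq=1, every epoch is checkpointed; larger values trade recovery granularity for less disk usage.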

Examples

>>> import paddle
>>> import paddle.vision.transforms as T
>>> from paddle.vision.datasets import MNIST
>>> from paddle.static import InputSpec

>>> inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
>>> labels = [InputSpec([None, 1], 'int64', 'label')]

>>> transform = T.Compose([
...     T.Transpose(),
...     T.Normalize([127.5], [127.5])
... ])
>>> train_dataset = MNIST(mode='train', transform=transform)

>>> lenet = paddle.vision.models.LeNet()
>>> model = paddle.Model(lenet,
...     inputs, labels)

>>> optim = paddle.optimizer.Adam(0.001, parameters=lenet.parameters())
>>> model.prepare(optimizer=optim,
...             loss=paddle.nn.CrossEntropyLoss(),
...             metrics=paddle.metric.Accuracy())

>>> callback = paddle.callbacks.ModelCheckpoint(save_dir='./temp')
>>> model.fit(train_dataset, batch_size=64, callbacks=callback)

on_epoch_begin(epoch=None, logs=None)

Called at the beginning of each epoch.

Parameters
  • epoch (int) – The index of the epoch.

  • logs (dict) – A dict or None. The logs passed by paddle.Model are None.

on_epoch_end(epoch, logs=None)

Called at the end of each epoch.

Parameters
  • epoch (int) – The index of the epoch.

  • logs (dict) – A dict or None. The logs passed by paddle.Model are a dict containing ‘loss’, the metrics, and ‘batch_size’ of the last batch.

on_train_end(logs=None)

Called at the end of training.

Parameters
  • logs (dict) – A dict or None. The keys of the logs passed by paddle.Model include ‘loss’, the metric names, and ‘batch_size’.
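To make the hook order concrete, here is a self-contained sketch that needs no Paddle install. `RecordingCheckpoint` and `fake_fit` are hypothetical names, not Paddle APIs; the checkpoint path layout (`<save_dir>/<epoch>` per epoch and `<save_dir>/final` at the end of training) and the `epoch % save_freq == 0` condition are assumptions, and the logs values are invented. The class records the paths it would save to instead of writing files.

```python
class RecordingCheckpoint:
    """Hypothetical stand-in for ModelCheckpoint that records paths instead of saving.

    Assumed path layout: <save_dir>/<epoch> per epoch, <save_dir>/final at the end.
    """

    def __init__(self, save_freq=1, save_dir=None):
        self.save_freq = save_freq
        self.save_dir = save_dir
        self.epoch = None
        self.saved = []  # paths that would have been passed to a save call

    def on_epoch_begin(self, epoch=None, logs=None):
        # Remember the current epoch index; paddle.Model passes logs=None here.
        self.epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        # In this sketch the save decision uses only epoch and save_freq; logs is ignored.
        if self.save_dir is not None and epoch % self.save_freq == 0:
            self.saved.append(f"{self.save_dir}/{epoch}")

    def on_train_end(self, logs=None):
        # Always record a final checkpoint when a directory is configured.
        if self.save_dir is not None:
            self.saved.append(f"{self.save_dir}/final")


def fake_fit(callback, epochs):
    """Call the hooks in the order a training loop would (hypothetical driver)."""
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        # Made-up values; only the key names mirror the documented logs.
        logs = {"loss": [0.5 / (epoch + 1)], "acc": 0.9, "batch_size": 64}
        callback.on_epoch_end(epoch, logs)
    callback.on_train_end({"loss": [0.1], "acc": 0.95, "batch_size": 64})


cb = RecordingCheckpoint(save_freq=2, save_dir="./temp")
fake_fit(cb, epochs=5)
print(cb.saved)  # ['./temp/0', './temp/2', './temp/4', './temp/final']
```

Passing save_dir=None leaves `saved` empty, mirroring the documented behavior that no checkpoints are written in that case.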