ModelCheckpoint
class paddle.callbacks.ModelCheckpoint(save_freq=1, save_dir=None)
Callback that saves the model weights and optimizer state during training. It is used together with model.fit(). Currently, ModelCheckpoint only supports saving at a fixed interval of epochs.
Parameters
save_freq (int) – How often, in epochs, a model checkpoint is saved. Default: 1.
save_dir (str|None) – The directory in which checkpoints are saved during training. If None, no checkpoints are saved. Default: None.
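To make the save_freq and save_dir semantics concrete, the sketch below lists the checkpoint paths that one fit() run might produce. It is plain Python, not Paddle code: checkpoint_paths is a hypothetical helper, and it assumes 0-based epoch indices, a save whenever epoch % save_freq == 0, checkpoints named save_dir/<epoch>, and one extra save to save_dir/final when training ends. Check these assumptions against the ModelCheckpoint source of your Paddle version.

```python
def checkpoint_paths(num_epochs, save_freq=1, save_dir=None):
    """Hypothetical helper: predict the checkpoint prefixes for one fit() run.

    Assumptions (not taken from the Paddle docs): epochs are 0-based, a
    checkpoint is written when epoch % save_freq == 0, per-epoch checkpoints
    are named save_dir/<epoch>, and a last one is written to save_dir/final.
    """
    if save_dir is None:
        # With no directory, nothing is saved at all.
        return []
    paths = [
        f"{save_dir}/{epoch}"
        for epoch in range(num_epochs)
        if epoch % save_freq == 0
    ]
    # One extra checkpoint when training finishes.
    paths.append(f"{save_dir}/final")
    return paths


print(checkpoint_paths(5, save_freq=2, save_dir="./temp"))
# ['./temp/0', './temp/2', './temp/4', './temp/final']
```

Under these assumptions, the default save_freq=1 writes a checkpoint after every epoch, so a larger value mainly trades recovery granularity for disk space.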
Examples
>>> import paddle
>>> import paddle.vision.transforms as T
>>> from paddle.vision.datasets import MNIST
>>> from paddle.static import InputSpec
>>> inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
>>> labels = [InputSpec([None, 1], 'int64', 'label')]
>>> transform = T.Compose([
...     T.Transpose(),
...     T.Normalize([127.5], [127.5])
... ])
>>> train_dataset = MNIST(mode='train', transform=transform)
>>> lenet = paddle.vision.models.LeNet()
>>> model = paddle.Model(lenet,
...     inputs, labels)
>>> optim = paddle.optimizer.Adam(0.001, parameters=lenet.parameters())
>>> model.prepare(optimizer=optim,
...     loss=paddle.nn.CrossEntropyLoss(),
...     metrics=paddle.metric.Accuracy())
>>> callback = paddle.callbacks.ModelCheckpoint(save_dir='./temp')
>>> model.fit(train_dataset, batch_size=64, callbacks=callback)
on_epoch_begin(epoch=None, logs=None)
Called at the beginning of each epoch.
Parameters
epoch (int) – The index of the epoch.
logs (dict|None) – A dict of logs, or None. The logs passed by paddle.Model are None.
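To show where on_epoch_begin sits in a training run, here is a pure-Python stand-in for the loop that model.fit() drives. RecordingCallback and fake_fit are hypothetical names, the logs values are placeholders rather than real training output, and the real fit() also calls hooks omitted here (for example on_train_begin and the per-batch hooks). Consistent with the parameter descriptions, the sketch passes logs=None at epoch start and a dict at epoch end.

```python
class RecordingCallback:
    """Hypothetical stand-in using the hook names of paddle.callbacks.Callback."""

    def __init__(self):
        self.calls = []  # (hook name, epoch, logs) in call order

    def on_epoch_begin(self, epoch=None, logs=None):
        self.calls.append(("on_epoch_begin", epoch, logs))

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(("on_epoch_end", epoch, logs))

    def on_train_end(self, logs=None):
        self.calls.append(("on_train_end", None, logs))


def fake_fit(callback, num_epochs):
    """Hypothetical driver that calls the three documented hooks in order."""
    logs = {}
    for epoch in range(num_epochs):
        callback.on_epoch_begin(epoch)  # logs stays None at epoch start
        # ... training batches would run here ...
        logs = {"loss": [0.0], "batch_size": 64}  # placeholder values
        callback.on_epoch_end(epoch, logs)
    callback.on_train_end(logs)


cb = RecordingCallback()
fake_fit(cb, num_epochs=2)
for name, epoch, logs in cb.calls:
    print(name, epoch, logs)
```

Because on_epoch_begin runs before any batch of the epoch, it is a natural place to remember the epoch index for use later in the same epoch.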
on_epoch_end(epoch, logs=None)
Called at the end of each epoch.
Parameters
epoch (int) – The index of the epoch.
logs (dict|None) – A dict of logs, or None. The logs passed by paddle.Model form a dict containing ‘loss’, the metrics, and ‘batch_size’ of the last batch.
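A ModelCheckpoint-style callback does its per-epoch saving in this hook. The sketch below shows one way such a hook could decide whether to save; it is not Paddle's source. FakeModel is a hypothetical stand-in for paddle.Model whose save() only records the path, EpochEndSaver is a hypothetical class, and the epoch % save_freq rule and save_dir/<epoch> naming are assumptions.

```python
class FakeModel:
    """Hypothetical stand-in for paddle.Model: save() only records the path."""

    def __init__(self):
        self.saved = []

    def save(self, path):
        self.saved.append(path)


class EpochEndSaver:
    """Sketch of a per-epoch save rule (an assumption, not Paddle's code)."""

    def __init__(self, model, save_freq=1, save_dir=None):
        self.model = model
        self.save_freq = save_freq
        self.save_dir = save_dir
        self.last_logs = None

    def on_epoch_end(self, epoch, logs=None):
        # Keep the logs dict (e.g. 'loss', metrics, 'batch_size') for inspection.
        self.last_logs = logs
        if self.save_dir is not None and epoch % self.save_freq == 0:
            self.model.save(f"{self.save_dir}/{epoch}")


model = FakeModel()
saver = EpochEndSaver(model, save_freq=3, save_dir="./temp")
for epoch in range(7):
    saver.on_epoch_end(epoch, {"loss": [0.0], "batch_size": 64})  # placeholder logs
print(model.saved)  # ['./temp/0', './temp/3', './temp/6']
```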
on_train_end(logs=None)
Called at the end of training.
Parameters
logs (dict|None) – A dict of logs, or None. The logs passed by paddle.Model contain the keys ‘loss’, the metric names, and ‘batch_size’.
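Besides the periodic saves, a checkpoint callback can write one last checkpoint when training finishes, so the final weights are kept even if the last epoch is not a multiple of save_freq. The sketch below illustrates that idea in plain Python; FakeModel and TrainEndSaver are hypothetical, and the save_dir/final name is an assumption to verify against the Paddle source.

```python
class FakeModel:
    """Hypothetical stand-in for paddle.Model: save() only records the path."""

    def __init__(self):
        self.saved = []

    def save(self, path):
        self.saved.append(path)


class TrainEndSaver:
    """Sketch of a final save in on_train_end (an assumption, not Paddle's code)."""

    def __init__(self, model, save_dir=None):
        self.model = model
        self.save_dir = save_dir
        self.final_log_keys = []

    def on_train_end(self, logs=None):
        # Record which keys fit() reported, e.g. 'loss', metric names, 'batch_size'.
        self.final_log_keys = sorted(logs) if logs else []
        if self.save_dir is not None:
            self.model.save(f"{self.save_dir}/final")


model = FakeModel()
saver = TrainEndSaver(model, save_dir="./temp")
saver.on_train_end({"loss": [0.0], "batch_size": 64})  # placeholder logs
print(model.saved)  # ['./temp/final']
```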