自定义Loss、Metric 及 Callback¶
除了使用飞桨框架内置的 API,有时会需要根据实际场景,自定义 Loss、Metric 及 Callback 来使用,本节介绍在飞桨框架中自定义的方法。
一、自定义损失函数 Loss¶
1.1 损失函数介绍¶
损失函数用来评估模型的预测结果与真实结果之间的差距,损失函数越小,模型的鲁棒性就越好。模型训练的过程其实是对损失函数采用梯度下降的方法,使得损失函数不断减小到局部最优值,而得到对任务来说比较合理的模型参数。
一般在深度学习任务中,有许多常用的损失函数,例如在图像分类任务中常用的交叉熵损失函数,在目标检测任务中常用的 Focal loss、L1/L2 损失函数等,在图像识别任务中常用的 Triplet Loss 以及 Center Loss 等。如果框架中提供的损失函数不能满足试图解决的问题,可按照框架的 API 结构要求自定义损失函数。
1.2 自定义 Loss 步骤¶
飞桨框架中实现自定义 Loss 的方法和使用 paddle.nn.Layer 进行模型组网的方法类似,包括三个步骤:
创建一个继承自
paddle.nn.Layer
的类;在类的构造函数
__init__
中定义需要的参数;在类的前向计算函数
forward
中进行损失函数计算。
import paddle
class SelfDefineLoss(paddle.nn.Layer):
"""
1. 继承paddle.nn.Layer
"""
def __init__(self):
"""
2. 构造函数根据自己的实际算法需求和使用需求进行参数定义即可
"""
super().__init__()
def forward(self, x, label):
"""
3. 实现forward函数,forward在调用时会传递两个参数:x和label
- x:单个或批次训练数据经过模型前向计算输出结果
- label:单个或批次训练数据对应的标签数据
接口返回值是一个Tensor,根据需要将所有x和label计算得到的loss值求和或取均值
"""
# 返回forword中计算的结果
# output = xxxxx
# return output
下面是定义交叉熵损失函数 Loss 的示例:
class CrossEntropy(paddle.nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x, label):
# 使用paddle内置的cross_entropy算子实现算法
loss = paddle.nn.functional.cross_entropy(
x,
label)
return loss
二、自定义评估指标 Metric¶
2.1 评估指标介绍¶
评估指标用来衡量一个模型的实际效果好坏,一般是通过计算模型的预测结果和真实结果之间的某种"距离"得出。
和损失函数类似,一般会在不同的任务场景中选择不同的评估指标来做模型评估,例如在分类任务中常见的评估指标包括了 Accuracy、Recall、Precision 和 AUC 等,在回归任务中常用的有 MAE 和 MSE 等等。这些常见的评估指标在飞桨框架中都有对应的 API 实现,可直接使用。如果不能满足需求,则可按照框架的 API 结构要求自定义评估指标。
2.2 自定义评估指标¶
通过框架实现自定义评估指标的方法,包括如下几个步骤:
创建一个继承自 paddle.metric.Metric 的类;
在类的构造函数
__init__
中定义需要的参数;实现
name
方法,返回定义的评估指标名字;实现
compute
方法,这个方法主要用于update
的加速,可省略;实现
update
方法,用于单个 batch 训练时进行评估指标计算;实现
accumulate
方法,返回历史 batch 训练积累后计算得到的评价指标值;实现
reset
方法,每个 epoch 结束后进行评估指标的重置。
class SelfDefineMetric(paddle.metric.Metric):
"""
1. 继承paddle.metric.Metric
"""
def __init__(self):
"""
2. 构造函数实现,自定义参数即可
"""
super(SelfDefineMetric, self).__init__()
def name(self):
"""
3. 实现name方法,返回定义的评估指标名字
"""
# return '自定义评价指标的名字'
def compute(self, **args):
"""
4. 本步骤可以省略,实现compute方法,这个方法主要用于`update`的加速,可以在这个方法中调用一些飞桨框架实现好的Tensor计算API,编译到模型网络中一起使用低层C++ OP计算。
"""
# return '自己想要返回的数据,会做为update的参数传入。'
def update(self, **args):
"""
5. 实现update方法,用于单个batch训练时进行评估指标计算。
- 当`compute`类函数未实现时,会将模型的计算输出和标签数据的展平作为`update`的参数传入。
- 当`compute`类函数做了实现时,会将compute的返回结果作为`update`的参数传入。
"""
# return acc_value
def accumulate(self):
"""
6. 实现accumulate方法,返回历史batch训练积累后计算得到的评价指标值。
每次`update`调用时进行数据积累,`accumulate`计算时对积累的所有数据进行计算并返回。
结算结果会在`fit`接口的训练日志中呈现。
"""
# 利用update中积累的成员变量数据进行计算后返回
# return accumulated_acc_value
def reset(self):
"""
7. 实现reset方法,每个Epoch结束后进行评估指标的重置,这样下个Epoch可以重新进行计算。
"""
# do reset action
接下来看一个框架中的具体例子,Accuracy 评价指标的示例,这里就是按照上述说明中的方法完成了实现。
class Accuracy(paddle.metric.Metric):
"""
继承paddle.metric.Metric
"""
def __init__(self, topk=(1, ), name=None, *args, **kwargs):
"""
构造函数实现
"""
super(Accuracy, self).__init__(*args, **kwargs)
self.topk = topk
self.maxk = max(topk)
self._init_name(name)
self.reset()
def compute(self, pred, label, *args):
"""
实现compute方法
"""
pred = paddle.argsort(pred, descending=True)
pred = paddle.slice(
pred, axes=[len(pred.shape) - 1], starts=[0], ends=[self.maxk])
if (len(label.shape) == 1) or \
(len(label.shape) == 2 and label.shape[-1] == 1):
# In static mode, the real label data shape may be different
# from shape defined by paddle.static.InputSpec in model
# building, reshape to the right shape.
label = paddle.reshape(label, (-1, 1))
elif label.shape[-1] != 1:
# one-hot label
label = paddle.argmax(label, axis=-1, keepdim=True)
correct = pred == label
return paddle.cast(correct, dtype='float32')
def update(self, correct, *args):
"""
实现update方法,用于单个batch训练时进行评估指标计算。
- 当`compute`类函数未实现时,会将模型的计算输出和标签数据的展平作为`update`的参数传入。
- 当`compute`类函数做了实现时,会将compute的返回结果作为`update`的参数传入。
"""
if isinstance(correct, paddle.Tensor):
correct = correct.numpy()
num_samples = np.prod(np.array(correct.shape[:-1]))
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[..., :k].sum()
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
accs = accs[0] if len(self.topk) == 1 else accs
return accs
def reset(self):
"""
实现reset方法,每个Epoch结束后进行评估指标的重置,这样下个Epoch可以重新进行计算。
"""
self.total = [0.] * len(self.topk)
self.count = [0] * len(self.topk)
def accumulate(self):
"""
实现accumulate方法,返回历史batch训练积累后计算得到的评价指标值。
每次`update`调用时进行数据积累,`accumulate`计算时对积累的所有数据进行计算并返回。
结算结果会在`fit`接口的训练日志中呈现。
"""
res = []
for t, c in zip(self.total, self.count):
r = float(t) / c if c > 0 else 0.
res.append(r)
res = res[0] if len(self.topk) == 1 else res
return res
def _init_name(self, name):
name = name or 'acc'
if self.maxk != 1:
self._name = ['{}_top{}'.format(name, k) for k in self.topk]
else:
self._name = [name]
def name(self):
"""
实现name方法,返回定义的评估指标名字
"""
return self._name
三、自定义回调函数 Callback¶
3.1 回调函数介绍¶
Callback 回调函数常用于对模型训练、评估、推理过程状态和参数的观察,如训练进度、loss 值等信息;也可用于实现一些自定义操作,如设置当 loss 值达到一定阈值时停止训练、按照设定规则定期保存模型等。可方便地掌握模型训练状态,及时做出灵活调整。
Callback 用在 Model.fit
、Model.evaluate
、Model.predict
等飞桨高层 API 中,先定义一个继承自 paddle.callbacks.Callback 的类,然后通过高层 API 接口的 callback 参数传入类的实例,用于模型训练、评估或推理过程中调用。
3.2 自定义回调函数¶
自定义回调函数的实现模板如下所示:
class SelfDefineCallback(paddle.callbacks.Callback):
"""
1. 继承paddle.callbacks.Callback
2. 按照自己的需求实现以下类成员方法:
def on_train_begin(self, logs=None) 训练开始前,`Model.fit`接口中调用
def on_train_end(self, logs=None) 训练结束后,`Model.fit`接口中调用
def on_eval_begin(self, logs=None) 评估开始前,`Model.evaluate`接口调用
def on_eval_end(self, logs=None) 评估结束后,`Model.evaluate`接口调用
def on_predict_begin(self, logs=None) 推理开始前,`Model.predict`接口中调用
def on_predict_end(self, logs=None) 推理结束后,`Model.predict`接口中调用
def on_epoch_begin(self, epoch, logs=None) 每轮训练开始前,`Model.fit`接口中调用
def on_epoch_end(self, epoch, logs=None) 每轮训练结束后,`Model.fit`接口中调用
def on_train_batch_begin(self, step, logs=None) 单个Batch训练开始前,`Model.fit`和`Model.train_batch`接口中调用
def on_train_batch_end(self, step, logs=None) 单个Batch训练结束后,`Model.fit`和`Model.train_batch`接口中调用
def on_eval_batch_begin(self, step, logs=None) 单个Batch评估开始前,`Model.evalute`和`Model.eval_batch`接口中调用
def on_eval_batch_end(self, step, logs=None) 单个Batch评估结束后,`Model.evalute`和`Model.eval_batch`接口中调用
def on_predict_batch_begin(self, step, logs=None) 单个Batch推理开始前,`Model.predict`和`Model.test_batch`接口中调用
def on_predict_batch_end(self, step, logs=None) 单个Batch推理结束后,`Model.predict`和`Model.test_batch`接口中调用
"""
def __init__(self):
super().__init__()
# 按照需求定义自己的类成员方法
飞桨框架在 paddle.callbacks 下内置了一些常用的回调函数相关 API,接下来看两个框架中的实际例子。其中第一个例子时框架自带的 ModelCheckpoint
回调函数,可以在 Model.fit
训练模型时自动存储每轮训练得到的模型;第二个例子是框架自带的 ProgBarLogger
回调函数,用于在 Model.fit
训练时打印损失函数和评估指标。这两个回调函数会在 Model.fit
执行时默认被调用。
class ModelCheckpoint(paddle.callbacks.Callback):
"""
继承paddle.callbacks.Callback,该类的功能是
训练模型时自动存储每轮训练得到的模型
"""
def __init__(self, save_freq=1, save_dir=None):
"""
构造函数实现
"""
self.save_freq = save_freq
self.save_dir = save_dir
def on_epoch_begin(self, epoch=None, logs=None):
"""
每轮训练开始前,获取当前轮数
"""
self.epoch = epoch
def _is_save(self):
return self.model and self.save_dir and ParallelEnv().local_rank == 0
def on_epoch_end(self, epoch, logs=None):
"""
每轮训练结束后,保存每轮的checkpoint
"""
if self._is_save() and self.epoch % self.save_freq == 0:
path = '{}/{}'.format(self.save_dir, epoch)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)
def on_train_end(self, logs=None):
"""
训练结束后,保存最后一轮的checkpoint
"""
if self._is_save():
path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)
import time
from paddle.distributed import ParallelEnv
from paddle.utils import try_import
from paddle.hapi.progressbar import ProgressBar
class ProgBarLogger(paddle.callbacks.Callback):
"""
继承paddle.callbacks.Callback,该类的功能是
训练模型时打印损失函数和评估指标
"""
def __init__(self, log_freq=1, verbose=2):
"""
构造函数实现
"""
self.epochs = None
self.steps = None
self.progbar = None
self.verbose = verbose
self.log_freq = log_freq
def _is_print(self):
return self.verbose and ParallelEnv().local_rank == 0
def on_train_begin(self, logs=None):
"""
训练开始前,获取总epoch、metric等信息
"""
self.epochs = self.params['epochs']
assert self.epochs
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._train_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
if self._is_print():
print(
"The loss value printed in the log is the current step, and the metric is the average value of previous steps."
)
def on_epoch_begin(self, epoch=None, logs=None):
"""
每轮训练开始前,获取当前轮数、步数,声明进度条与计时器等
"""
self.steps = self.params['steps']
self.epoch = epoch
self.train_step = 0
if self.epochs and self._is_print():
print('Epoch %d/%d' % (epoch + 1, self.epochs))
self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose)
self._train_timer['batch_start_time'] = time.time()
def _updates(self, logs, mode):
values = []
metrics = getattr(self, '%s_metrics' % (mode))
progbar = getattr(self, '%s_progbar' % (mode))
steps = getattr(self, '%s_step' % (mode))
for k in metrics:
if k in logs:
values.append((k, logs[k]))
if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
timer = getattr(self, '_%s_timer' % (mode))
cnt = timer['count'] if timer['count'] > 0 else 1.0
samples = timer['samples'] if timer['samples'] > 0 else 1.0
values.append(
('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt)))
values.append(
('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt)))
values.append(
('ips', "%.5f samples/sec" %
(samples / (timer['data_time'] + timer['batch_time']))))
timer['count'] = 0
timer['samples'] = 0
timer['data_time'] = 0.
timer['batch_time'] = 0.
progbar.update(steps, values)
def on_train_batch_begin(self, step, logs=None):
"""
单个Batch训练开始前,进行计时
"""
self._train_timer['batch_data_end_time'] = time.time()
self._train_timer['data_time'] += (
self._train_timer['batch_data_end_time'] -
self._train_timer['batch_start_time'])
def on_train_batch_end(self, step, logs=None):
"""
单个Batch训练结束后,更新参数
"""
logs = logs or {}
self.train_step += 1
self._train_timer['batch_time'] += (
time.time() - self._train_timer['batch_data_end_time'])
self._train_timer['count'] += 1
samples = logs.get('batch_size', 1)
self._train_timer['samples'] += samples
if self._is_print() and self.train_step % self.log_freq == 0:
if self.steps is None or self.train_step < self.steps:
self._updates(logs, 'train')
self._train_timer['batch_start_time'] = time.time()
def on_epoch_end(self, epoch, logs=None):
"""
每轮训练结束后,更新参数
"""
logs = logs or {}
if self._is_print() and (self.steps is not None):
self._updates(logs, 'train')
def on_eval_begin(self, logs=None):
"""
评估开始前,获取当前步数,声明进度条与计时器等
"""
self.eval_steps = logs.get('steps', None)
self.eval_metrics = logs.get('metrics', [])
self.eval_step = 0
self.evaled_samples = 0
self._eval_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
self.eval_progbar = ProgressBar(
num=self.eval_steps, verbose=self.verbose)
if self._is_print():
print('Eval begin...')
self._eval_timer['batch_start_time'] = time.time()
def on_eval_batch_begin(self, step, logs=None):
"""
单个Batch评估开始前,进行计时
"""
self._eval_timer['batch_data_end_time'] = time.time()
self._eval_timer['data_time'] += (
self._eval_timer['batch_data_end_time'] -
self._eval_timer['batch_start_time'])
def on_eval_batch_end(self, step, logs=None):
"""
单个Batch评估结束后,更新参数
"""
logs = logs or {}
self.eval_step += 1
samples = logs.get('batch_size', 1)
self.evaled_samples += samples
self._eval_timer['batch_time'] += (
time.time() - self._eval_timer['batch_data_end_time'])
self._eval_timer['count'] += 1
samples = logs.get('batch_size', 1)
self._eval_timer['samples'] += samples
if self._is_print() and self.eval_step % self.log_freq == 0:
if self.eval_steps is None or self.eval_step < self.eval_steps:
self._updates(logs, 'eval')
self._eval_timer['batch_start_time'] = time.time()
def on_predict_begin(self, logs=None):
"""
推理开始前,获取当前步数,声明进度条与计时器等
"""
self.test_steps = logs.get('steps', None)
self.test_metrics = logs.get('metrics', [])
self.test_step = 0
self.tested_samples = 0
self._test_timer = {
'data_time': 0,
'batch_time': 0,
'count': 0,
'samples': 0,
}
self.test_progbar = ProgressBar(
num=self.test_steps, verbose=self.verbose)
if self._is_print():
print('Predict begin...')
self._test_timer['batch_start_time'] = time.time()
def on_predict_batch_begin(self, step, logs=None):
"""
单个Batch推理开始前,进行计时
"""
self._test_timer['batch_data_end_time'] = time.time()
self._test_timer['data_time'] += (
self._test_timer['batch_data_end_time'] -
self._test_timer['batch_start_time'])
def on_predict_batch_end(self, step, logs=None):
"""
单个Batch推理结束后,更新参数
"""
logs = logs or {}
self.test_step += 1
samples = logs.get('batch_size', 1)
self.tested_samples += samples
self._test_timer['batch_time'] += (
time.time() - self._test_timer['batch_data_end_time'])
self._test_timer['count'] += 1
samples = logs.get('batch_size', 1)
self._test_timer['samples'] += samples
if self.test_step % self.log_freq == 0 and self._is_print():
if self.test_steps is None or self.test_step < self.test_steps:
self._updates(logs, 'test')
self._test_timer['batch_start_time'] = time.time()
def on_eval_end(self, logs=None):
"""
评估结束后,更新参数,打印信息
"""
logs = logs or {}
if self._is_print() and (self.eval_steps is not None):
self._updates(logs, 'eval')
print('Eval samples: %d' % (self.evaled_samples))
def on_predict_end(self, logs=None):
"""
推理结束后,更新参数,打印信息
"""
logs = logs or {}
if self._is_print():
if self.test_step % self.log_freq != 0 or self.verbose == 1:
self._updates(logs, 'test')
print('Predict samples: %d' % (self.tested_samples))
四、自定义Loss、Metric 及 Callback 的使用¶
以下代码示例中,介绍了自定义 Loss、Metric 及 Callback 后,如何在模型训练中使用。自定义的 loss、Metric 可传入 paddle.Model.prepare
中完成训练准备配置,callback 可传入 paddle.Model.fit
中在模型训练中调用。
import paddle
import numpy as np
from paddle.vision.transforms import Normalize
transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 加载数据集
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
mnist = paddle.nn.Sequential(
paddle.nn.Flatten(1, -1),
paddle.nn.Linear(784, 512),
paddle.nn.ReLU(),
paddle.nn.Dropout(0.2),
paddle.nn.Linear(512, 10)
)
model = paddle.Model(mnist)
# 将paddle.nn.CrossEntropyLoss替换为CrossEntropy
# 将paddle.metric.Accuracy替换为Accuracy
model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()),
loss=CrossEntropy(),
metrics=Accuracy())
# 启动模型训练,加入自定义的两个Callbacks
model.fit(train_dataset,
epochs=5,
batch_size=64,
verbose=0,
callbacks=[ProgBarLogger(verbose=1), ModelCheckpoint()]
)
W1223 04:29:17.810079 9910 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.2, Runtime API Version: 10.2
W1223 04:29:17.815956 9910 device_context.cc:465] device: 0, cuDNN Version: 7.6.
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/5
step 2/938 [..............................] - loss: 2.1375 - acc: 0.1562 - ETA: 7:27 - 478ms/step
/usr/local/python3.7.0/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return (isinstance(seq, collections.Sequence) and
step 938/938 [==============================] - loss: 0.1039 - acc: 0.9303 - 138ms/step
Epoch 2/5
step 938/938 [==============================] - loss: 0.0887 - acc: 0.9696 - 123ms/step
Epoch 3/5
step 938/938 [==============================] - loss: 0.0285 - acc: 0.9782 - 158ms/step
Epoch 4/5
step 938/938 [==============================] - loss: 0.0049 - acc: 0.9833 - 158ms/step
Epoch 5/5
step 938/938 [==============================] - loss: 0.1041 - acc: 0.9863 - 145ms/step
五、总结¶
本节中介绍了飞桨框架中一些高阶自定义用法,包括自定义 Loss、Metric 及 Callback。飞桨框架既内置了丰富的组件,方便用户直接使用提升模型开发效率,也提供开放的接口方便用户根据任务需求自定义组件来使用,以便更灵活地进行模型开发。