paddle.amp¶
paddle.amp 目录下包含飞桨框架支持的动态图自动混合精度(AMP)相关的 API。具体如下:
paddle.amp 目录下包含 debugging 目录, debugging 目录中存放用于算子模型精度问题定位的 API。具体如下:
AMP 相关 API¶
API 名称 |
API 功能 |
---|---|
|
创建 AMP 上下文环境 |
|
根据选定混合精度训练模式,改写神经网络参数数据类型 |
|
控制 loss 的缩放比例 |
开启 AMP 后默认转化为 float16 计算的相关 OP¶
OP 名称 |
OP 功能 |
---|---|
conv2d |
卷积计算 |
matmul |
矩阵乘法 |
matmul_v2 |
矩阵乘法 |
mul |
矩阵乘法 |
开启 AMP 后默认使用 float32 计算的相关 OP¶
OP 名称 |
OP 功能 |
---|---|
exp |
指数运算 |
square |
平方运算 |
log |
对数运算 |
mean |
取平均值 |
sum |
求和运算 |
cos_sim |
余弦相似度 |
softmax |
softmax 操作 |
softmax_with_cross_entropy |
softmax 交叉熵损失函数 |
sigmoid_cross_entropy_with_logits |
按元素的概率误差 |
cross_entropy |
交叉熵 |
cross_entropy2 |
交叉熵 |
AMP 场景下判断设备是否支持特定数据类型¶
API 名称 |
API 功能 |
---|---|
|
判断设备是否支持 bfloat16 |
|
判断设备是否支持 float16 |
Debug 相关辅助类¶
类名称 |
辅助类功能 |
---|---|
|
精度调试模式 |
|
精度调试配置类 |
算子调用统计相关 API¶
API 名称 |
API 功能 |
---|---|
|
收集不同数据类型的算子调用次数 |
|
启用以收集不同数据类型的算子调用次数 |
|
禁用收集不同数据类型的算子调用次数 |
模块级别精度定位 API¶
API 名称 |
API 功能 |
---|---|
|
开启模块级别的精度检查 |
|
关闭模块级别的精度检查 |
|
精度比对接口 |
数值检查相关 API¶
API 名称 |
API 功能 |
---|---|
|
Layer 输入、输出数据的数值检查 |
|
调试 Tensor 数值,检查其异常值(NaN、Inf) 和零元素 |