recompute

paddle.distributed.fleet.utils. recompute ( function, *args, **kwargs ) [源代码]

重新计算中间激活函数值来节省显存。

参数

  • function (paddle.nn.Layer) - 模型前向传播的部分连续的层函数组成的序列,它们的中间激活函数值将在前向传播过程中被释放掉来节省显存,并且在反向梯度计算的时候会重新被计算。

  • args (Tensor) - function 的输入。

  • kwargs (Dict) - kwargs 只应该包含两类键值对。一类键值是 function 的字典参数,另外一类仅只能包含 preserve_rng_state 和 use_reentrant 两个 key。 preserve_rng_state 的键值对,用来表示是否保存前向的 rng,如果为 True,那么在反向传播的重计算前向时会还原上次前向的 rng 值。默认 preserve_rng_state 为 True。 use_reentrant 的键值对,用来表示 recompute 的实现方式,如果为 True,意味着 recompute 使用 PyLayer 的方式实现的,如果为 False, recompute 内部则使用 hook 的方式实现的,默认值是 True。在某些场景下,比如 recompute 与数据并行结合时,需要额外调用 no_sync 函数,此时可以设置 use_reentrant=False,选用 hook 方式的 recompute,可以避免额外调用 no_sync 函数。

返回

function 作用在输入的输出

代码示例

import paddle
from paddle.distributed.fleet.utils import recompute
import random
# required: gpu
def get_fc_block(block_idx, input_size, is_last=False):
    block_name = "block_" + str(block_idx)
    block = paddle.nn.Sequential(
        (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
        (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
        (block_name + "_relu_1", paddle.nn.ReLU()),
        (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)),
        (block_name + "_relu_2", paddle.nn.ReLU()),
    )
    if is_last:
        block.add_sublayer(
            block_name + "_fc_2",
            paddle.nn.Linear(
                input_size, 1, bias_attr=False
            )
        )
    else:
        block.add_sublayer(
            block_name + "_fc_2",
            paddle.nn.Linear(input_size, input_size, bias_attr=False)
        )
    return block
class Naive_fc_net(paddle.nn.Layer):
    def __init__(self, input_size=10,
                recompute_blocks=[1, 3],
                recompute_kwargs={}):
        super().__init__()
        self.recompute_blocks = recompute_blocks
        self.recompute_kwargs = recompute_kwargs
        self.runfunc0 = get_fc_block(0, input_size, is_last=False)
        self.runfunc1 = get_fc_block(1, input_size, is_last=False)
        self.runfunc2 = get_fc_block(2, input_size, is_last=False)
        self.runfunc3 = get_fc_block(3, input_size, is_last=False)
        self.runfunc4 = get_fc_block(4, input_size, is_last=True)
        self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4]
    def forward(self, inputs):
        nums = len(self.total_func)
        for i in range(nums):
            if i in self.recompute_blocks:
                inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True})
            else:
                inputs = self.total_func[i](inputs)
        return inputs
def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
    gen = paddle.seed(10)
    gen.manual_seed(10)
    random.seed(10)
    if cuda_state:
        paddle.set_cuda_rng_state(cuda_state)
    batch_size, input_size = 1, 10
    model = Naive_fc_net(
        input_size,
        recompute_blocks=recompute_block,
        recompute_kwargs=recompute_kwargs)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    loss_ = []
    param_ = []
    grad_ = []
    for _ in range(5):
        x = paddle.rand(shape=[batch_size, input_size], dtype="float32")
        y_pred = model(x)
        loss = y_pred.mean()
        loss_.append(loss.item())
        loss.backward()
        optimizer.step()
        param_.append(model.parameters()[9])
        grad_.append(model.parameters()[3]._grad_ivar())
        optimizer.clear_grad()
    return loss_, param_, grad_
cuda_state = paddle.get_cuda_rng_state()
# without recompute
loss_ref, param_ref, grad_ref = run_model(
    cuda_state, recompute_block=[]
)
loss, param, grad = run_model(cuda_state, recompute_block=[1, 2])
print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss))
# The result of the recompute_loss should be the same as the normal_loss.

使用本API的教程文档