DataParallel

class paddle.DataParallel(layers, strategy=None, comm_buffer_size=25, last_comm_buffer_size=1, find_unused_parameters=False)

Runs a dynamic graph (dygraph) model with data parallelism.

Currently, the DataParallel class only supports running dynamic graphs in multi-process mode.

Two ways of starting training are supported:

  1. Start with the paddle.distributed.spawn method, for example:

    python demo.py (spawn must be called in the __main__ block)

  2. Start with the paddle.distributed.launch module, for example:

    python -m paddle.distributed.launch --gpus=0,1 demo.py

In both cases, demo.py contains the code of the examples below.

Parameters
  • layers (Layer) – The model to be run in data-parallel mode.

  • strategy (ParallelStrategy, optional) – (deprecated) The data-parallel strategy, containing the environment configuration for parallel execution. Default: None.

  • comm_buffer_size (int, optional) – The maximum memory size (in MB) of one buffer of parameter gradients, which is the input of a single communication call (e.g. NCCLAllReduce). Default: 25.

  • last_comm_buffer_size (float, optional) – The maximum memory size (in MB) of the last buffer in the communication calls. Keeping the last communication buffer small can improve performance. Default: 1.

  • find_unused_parameters (bool, optional) – Whether to traverse the entire backward graph starting from all tensors in the return value of the wrapped model’s forward function. Gradients of parameters that are not involved in the loss calculation are marked as ready in advance so that the reduce can proceed. Note that all forward outputs derived from the wrapped model’s parameters must participate in the loss calculation and the subsequent gradient calculation; otherwise, a serious error will occur. Setting find_unused_parameters to True degrades computing performance, so if you are sure that all parameters participate in the loss calculation and the autograd graph construction, set it to False. Default: False.
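
To make the two buffer-size parameters more concrete, here is a minimal pure-Python sketch of size-capped gradient bucketing. It is an illustrative model, not Paddle's implementation: the function name assign_buckets is hypothetical, and it assumes that gradients become ready in roughly reverse registration order during backward, so the bucket holding the first-registered parameters is communicated last and is capped by last_comm_buffer_size.

```python
def assign_buckets(param_bytes, comm_buffer_size=25, last_comm_buffer_size=1):
    """Hypothetical sketch: group parameter indices into all-reduce buckets.

    param_bytes: gradient size in bytes of each parameter, in registration order.
    Returns the buckets in the order they would be communicated.
    """
    mb = 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    # The first-registered parameters finish backward last, so their bucket is
    # the last one communicated; cap it with the smaller limit.
    limit = last_comm_buffer_size * mb
    for idx, nbytes in enumerate(param_bytes):
        # Close the current bucket if adding this gradient would overflow it.
        if current and current_bytes + nbytes > limit:
            buckets.append(current)
            current, current_bytes = [], 0
            limit = comm_buffer_size * mb  # all later buckets use the big limit
        current.append(idx)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    # Communication follows the backward pass, i.e. reverse registration order.
    return list(reversed(buckets))


sizes = [512 * 1024, 768 * 1024] + [10 * 1024 * 1024] * 3
print(assign_buckets(sizes))  # [[4], [1, 2, 3], [0]]
```

With these sizes, the small first-registered parameter ends up alone in the final bucket, so only a little communication remains after the backward pass finishes.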

Returns

The data-parallel module.

Return type

Layer
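
The traversal that find_unused_parameters enables can be pictured as a reachability search over the autograd graph. The following pure-Python sketch is a conceptual model only: the graph is a hypothetical dict mapping each tensor name to the names it was computed from, and find_unused is not a Paddle function.

```python
def find_unused(outputs, graph, params):
    """Return the parameters that no forward output depends on.

    outputs: names of the tensors returned by forward().
    graph:   hypothetical dict mapping a tensor name to its input names.
    params:  names of the model's trainable parameters.
    """
    visited, stack = set(), list(outputs)
    # Depth-first walk from the outputs back towards the leaves.
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        stack.extend(graph.get(node, ()))
    # Parameters never reached get no gradient; their gradients would be
    # marked ready in advance so the reduce of their buffer is not blocked.
    return [p for p in params if p not in visited]


graph = {
    "out": ["h", "w2"],
    "h": ["x", "w1"],
    "aux": ["x", "w3"],  # computed in forward but not returned
}
print(find_unused(["out"], graph, ["w1", "w2", "w3"]))  # ['w3']
```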

Examples

# required: distributed
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear1 = nn.Linear(10, 10)
        self._linear2 = nn.Linear(10, 1)

    def forward(self, x):
        return self._linear2(self._linear1(x))

def train():
    # 1. initialize parallel environment
    dist.init_parallel_env()

    # 2. create data parallel layer & optimizer
    layer = LinearNet()
    dp_layer = paddle.DataParallel(layer)

    loss_fn = nn.MSELoss()
    adam = opt.Adam(
        learning_rate=0.001, parameters=dp_layer.parameters())

    # 3. run layer
    inputs = paddle.randn([10, 10], 'float32')
    outputs = dp_layer(inputs)
    labels = paddle.randn([10, 1], 'float32')
    loss = loss_fn(outputs, labels)

    loss.backward()

    adam.step()
    adam.clear_grad()

if __name__ == '__main__':
    # 1. start by ``paddle.distributed.spawn`` (default)
    dist.spawn(train, nprocs=2)
    # 2. start by ``paddle.distributed.launch``
    # train()

Note

PyLayer is not supported in DataParallel. To work around this, skip gradient synchronization across cards with ‘no_sync’ and manually perform the ‘all_reduce’ before the optimizer step. The following example shows a concrete implementation.

Examples

# required: distributed
import numpy
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        grad = dy * (1 - paddle.square(y))
        return grad

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(2, 2)

    def forward(self, inputs):
        inputs = cus_tanh.apply(inputs)
        return self.linear(inputs)

if __name__ == '__main__':
    dist.init_parallel_env()

    model = SimpleNet()
    model = paddle.DataParallel(model)
    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

    for step in range(10):
        x_data = numpy.random.randn(2, 2).astype(numpy.float32)
        x = paddle.to_tensor(x_data)
        x.stop_gradient = False

        # step 1 : skip gradient synchronization by 'no_sync'
        with model.no_sync():
            y_pred = model(x)
            loss = y_pred.mean()
            loss.backward()

        # step 2 : fuse + allreduce manually before optimization
        fused_allreduce_gradients(list(model.parameters()), None)

        opt.step()
        opt.clear_grad()
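
Step 2 of the example relies on fused_allreduce_gradients. Conceptually, a fused all-reduce packs many gradients into one flat buffer so that a single collective call replaces one call per parameter. The pure-Python sketch below simulates that idea for several ranks in one process; the name fused_allreduce and the averaging across ranks are assumptions made for illustration, not a description of Paddle's internals.

```python
def fused_allreduce(per_rank_grads):
    """Simulate a fused all-reduce (mean) across ranks in one process.

    per_rank_grads[r][i] is gradient i on rank r, as a flat list of floats.
    Returns the averaged gradients, identical for every rank.
    """
    sizes = [len(g) for g in per_rank_grads[0]]
    # 1. Coalesce: flatten each rank's gradients into one contiguous buffer.
    buffers = [[v for g in grads for v in g] for grads in per_rank_grads]
    # 2. One collective over the whole buffer: element-wise mean across ranks.
    world_size = len(buffers)
    reduced = [sum(col) / world_size for col in zip(*buffers)]
    # 3. Split the reduced buffer back into per-parameter gradients.
    result, offset = [], 0
    for n in sizes:
        result.append(reduced[offset:offset + n])
        offset += n
    return result


rank0 = [[1.0, 2.0], [3.0]]
rank1 = [[3.0, 4.0], [5.0]]
print(fused_allreduce([rank0, rank1]))  # [[2.0, 3.0], [4.0]]
```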

no_sync()

A context manager that disables gradient synchronization. Inside no_sync(), parameter gradients are only accumulated locally on the model and are not synchronized until the first forward-backward pass outside this context.

Examples

# required: distributed
import paddle
import paddle.nn as nn
import paddle.distributed as dist

class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self._linear = nn.Linear(10, 1)

    def forward(self, x):
        return self._linear(x)

dist.init_parallel_env()
model = SimpleNet()
dp_model = paddle.DataParallel(model)

inputs_1 = paddle.randn([10, 10], 'float32')
inputs_2 = paddle.ones([10, 10], 'float32')

with dp_model.no_sync():
    # gradients will not be synchronized
    dp_model(inputs_1).backward()

# synchronization happens here
dp_model(inputs_2).backward()
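
The accumulate-then-synchronize behavior of no_sync can be modeled without GPUs. In the following pure-Python sketch, ToyDataParallel is a hypothetical stand-in that keeps one scalar gradient per rank and assumes gradients are averaged across ranks when synchronization happens.

```python
import contextlib


class ToyDataParallel:
    """Hypothetical model of DataParallel gradient synchronization."""

    def __init__(self, world_size):
        self.grads = [0.0] * world_size  # one accumulated gradient per rank
        self._sync = True

    @contextlib.contextmanager
    def no_sync(self):
        # Disable synchronization inside the block, restore it afterwards.
        previous = self._sync
        self._sync = False
        try:
            yield
        finally:
            self._sync = previous

    def backward(self, local_grads):
        # Each rank accumulates its own local gradient.
        for rank, g in enumerate(local_grads):
            self.grads[rank] += g
        if self._sync:
            # Synchronize everything accumulated so far (mean over ranks).
            mean = sum(self.grads) / len(self.grads)
            self.grads = [mean] * len(self.grads)


dp = ToyDataParallel(world_size=2)
with dp.no_sync():
    dp.backward([1.0, 3.0])
print(dp.grads)  # [1.0, 3.0]  (still rank-local)
dp.backward([2.0, 2.0])
print(dp.grads)  # [4.0, 4.0]  (synchronized)
```

This is why no_sync is commonly used for gradient accumulation: only the last micro-batch of each accumulation window pays the communication cost.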

forward(*inputs, **kwargs)

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters
  • *inputs (tuple) – unpacked tuple arguments

  • **kwargs (dict) – unpacked dict arguments

scale_loss(loss)

Deprecated. scale_loss is now an empty method and is kept only for compatibility.

apply_collective_grads()

Deprecated. apply_collective_grads is now an empty method and is kept only for compatibility.

state_dict(destination=None, include_sublayers=True, structured_name_prefix='')

Gets all parameters and persistable buffers of the current layer and its sublayers and returns them in a dict.

Parameters
  • destination (dict, optional) – If provided, all the parameters and persistable buffers are stored in this dict. Default: None

  • include_sublayers (bool, optional) – If True, also include the parameters and persistable buffers from sublayers. Default: True

Returns

A dict containing all the parameters and persistable buffers.

Return type

dict

Examples

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

emb = paddle.nn.Embedding(10, 10)
emb = paddle.DataParallel(emb)

state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy.pdparams")

add_parameter(name, parameter)

Adds a Parameter instance.

The added parameter can be accessed as self.name.

Parameters
  • name (str) – name of this parameter.

  • parameter (Parameter) – an instance of Parameter.

Returns

the parameter passed in.

Return type

Parameter

Examples

import paddle

class MyLayer(paddle.nn.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self._linear = paddle.nn.Linear(1, 1)
        w_tmp = self.create_parameter([1,1])
        self.add_parameter("w_tmp", w_tmp)

    def forward(self, input):
        return self._linear(input)

mylayer = MyLayer()
for name, param in mylayer.named_parameters():
    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

add_sublayer(name, sublayer)

Adds a sublayer instance.

The added sublayer can be accessed as self.name.

Parameters
  • name (str) – name of this sublayer.

  • sublayer (Layer) – an instance of Layer.

Returns

the sublayer passed in.

Return type

Layer

Examples

import paddle

class MySequential(paddle.nn.Layer):
    def __init__(self, *layers):
        super(MySequential, self).__init__()
        if len(layers) > 0 and isinstance(layers[0], tuple):
            for name, layer in layers:
                self.add_sublayer(name, layer)
        else:
            for idx, layer in enumerate(layers):
                self.add_sublayer(str(idx), layer)

    def forward(self, input):
        for layer in self._sub_layers.values():
            input = layer(input)
        return input

fc1 = paddle.nn.Linear(10, 3)
fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
model = MySequential(fc1, fc2)
for prefix, layer in model.named_sublayers():
    print(prefix, layer)

apply(fn)

Applies fn recursively to every sublayer (as returned by .sublayers()) as well as self. Typical use includes initializing the parameters of a model.

Parameters

fn (function) – a function to be applied to each sublayer

Returns

self

Return type

Layer

Examples

import paddle
import paddle.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

def init_weights(layer):
    if type(layer) == nn.Linear:
        print('before init weight:', layer.weight.numpy())
        new_weight = paddle.full(shape=layer.weight.shape, dtype=layer.weight.dtype, fill_value=0.9)
        layer.weight.set_value(new_weight)
        print('after init weight:', layer.weight.numpy())

net.apply(init_weights)

print(net.state_dict())
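
A minimal sketch of the recursion behind apply, using a hypothetical ToyLayer class instead of paddle.nn.Layer. It assumes a children-first (post-order) visiting order, as in PyTorch's Module.apply; treat the exact order in Paddle as an assumption to verify.

```python
class ToyLayer:
    """Hypothetical stand-in for a layer with child layers."""

    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def apply(self, fn):
        # Recurse into every child first, then call fn on this layer.
        for child in self._children:
            child.apply(fn)
        fn(self)
        return self  # returning self allows chaining


visited = []
net = ToyLayer("net", [ToyLayer("0"), ToyLayer("1")])
net.apply(lambda layer: visited.append(layer.name))
print(visited)  # ['0', '1', 'net']
```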

buffers(include_sublayers=True)

Returns a list of all buffers from the current layer and its sublayers.

Parameters

include_sublayers (bool, optional) – Whether to include the buffers of sublayers. If True, also include the buffers from sublayers. Default: True

Returns

a list of buffers.

Return type

list of Tensor

Examples

import numpy as np
import paddle

linear = paddle.nn.Linear(10, 3)
value = np.array([0]).astype("float32")
buffer = paddle.to_tensor(value)
linear.register_buffer("buf_name", buffer, persistable=True)

print(linear.buffers())     # == print([linear.buf_name])

children()

Returns an iterator over immediate children layers.

Yields

Layer – a child layer

Examples

import paddle

linear1 = paddle.nn.Linear(10, 3)
linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
model = paddle.nn.Sequential(linear1, linear2)

layer_list = list(model.children())

print(layer_list)   # [<paddle.nn.layer.common.Linear object at 0x7f7b8113f830>, <paddle.nn.layer.common.Linear object at 0x7f7b8113f950>]

clear_gradients()

Clear the gradients of all parameters for this layer.

Returns

None

Examples

import paddle
import numpy as np

value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5)
adam = paddle.optimizer.Adam(learning_rate=0.01,
                            parameters=linear.parameters())
out = linear(a)
out.backward()
adam.step()
linear.clear_gradients()

create_parameter(shape, attr=None, dtype=None, is_bias=False, default_initializer=None)

Create parameters for this layer.

Parameters
  • shape (list) – Shape of the parameter.

  • attr (ParamAttr, optional) – Parameter attribute of weight. Please refer to ParamAttr. Default: None.

  • dtype (str, optional) – Data type of this parameter. If given as a str, it can be “bool”, “float16”, “float32”, “float64”, “int8”, “int16”, “int32”, “int64”, “uint8” or “uint16”. Default: “float32”.

  • is_bias (bool, optional) – Whether this is a bias parameter. Default: False.

  • default_initializer (Initializer, optional) – The default initializer for this parameter. If None, paddle.nn.initializer.Xavier is used for non-bias parameters and paddle.nn.initializer.Constant for bias parameters. Default: None.

Returns

Tensor, created parameter.

Examples

import paddle

class MyLayer(paddle.nn.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self._linear = paddle.nn.Linear(1, 1)
        w_tmp = self.create_parameter([1,1])
        self.add_parameter("w_tmp", w_tmp)

    def forward(self, input):
        return self._linear(input)

mylayer = MyLayer()
for name, param in mylayer.named_parameters():
    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

create_tensor(name=None, persistable=None, dtype=None)

Create a Tensor for this layer.

Parameters
  • name (str, optional) – name of the tensor. Please refer to Name. Default: None

  • persistable (bool, optional) – whether this tensor is persistable. Default: False

  • dtype (str, optional) – data type of this tensor. If given as a str, it can be “bool”, “float16”, “float32”, “float64”, “int8”, “int16”, “int32”, “int64”, “uint8” or “uint16”. If None, it will be “float32”. Default: None

Returns

Tensor, created Tensor.

Examples

import paddle

class MyLinear(paddle.nn.Layer):
    def __init__(self,
                 in_features,
                 out_features):
        super(MyLinear, self).__init__()
        self.linear = paddle.nn.Linear(in_features, out_features)

        self.back_var = self.create_tensor(name="linear_tmp_0", dtype=self._dtype)

    def forward(self, input):
        out = self.linear(input)
        paddle.assign(out, self.back_var)

        return out

create_variable(name=None, persistable=None, dtype=None)

Create a Tensor for this layer.

Parameters
  • name (str, optional) – name of the tensor. Please refer to Name. Default: None

  • persistable (bool, optional) – whether this tensor is persistable. Default: False

  • dtype (str, optional) – data type of this tensor. If given as a str, it can be “bool”, “float16”, “float32”, “float64”, “int8”, “int16”, “int32”, “int64”, “uint8” or “uint16”. If None, it will be “float32”. Default: None

Returns

Tensor, created Tensor.

Examples

import paddle

class MyLinear(paddle.nn.Layer):
    def __init__(self,
                 in_features,
                 out_features):
        super(MyLinear, self).__init__()
        self.linear = paddle.nn.Linear(in_features, out_features)

        self.back_var = self.create_variable(name="linear_tmp_0", dtype=self._dtype)

    def forward(self, input):
        out = self.linear(input)
        paddle.assign(out, self.back_var)

        return out

eval()

Sets this Layer and all its sublayers to evaluation mode. This only affects certain modules like Dropout and BatchNorm.

Returns

None

Examples

import paddle

class MyLayer(paddle.nn.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self._linear = paddle.nn.Linear(1, 1)
        self._dropout = paddle.nn.Dropout(p=0.5)

    def forward(self, input):
        temp = self._linear(input)
        temp = self._dropout(temp)
        return temp

x = paddle.randn([10, 1], 'float32')
mylayer = MyLayer()
mylayer.eval()  # set mylayer._dropout to eval mode
out = mylayer(x)
print(out)

extra_repr()

Extra representation of this layer. You can provide a custom implementation for your own layer.

full_name()

Full name of this layer, composed of name_scope + “/” + MyLayer.__class__.__name__

Returns

full name of this layer.

Return type

str

Examples

import paddle

class LinearNet(paddle.nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__(name_scope = "demo_linear_net")
        self._linear = paddle.nn.Linear(1, 1)

    def forward(self, x):
        return self._linear(x)

linear_net = LinearNet()
print(linear_net.full_name())   # demo_linear_net_0

named_buffers(prefix='', include_sublayers=True)

Returns an iterator over all buffers in the Layer, yielding tuples of name and Tensor.

Parameters
  • prefix (str, optional) – Prefix to prepend to all buffer names. Default: ‘’.

  • include_sublayers (bool, optional) – Whether to include the buffers of sublayers. If True, also include the named buffers from sublayers. Default: True.

Yields

(string, Tensor) – Tuple of name and tensor

Examples

import numpy as np
import paddle

fc1 = paddle.nn.Linear(10, 3)
buffer1 = paddle.to_tensor(np.array([0]).astype("float32"))
# register a tensor as buffer by specific `persistable`
fc1.register_buffer("buf_name_1", buffer1, persistable=True)

fc2 = paddle.nn.Linear(3, 10)
buffer2 = paddle.to_tensor(np.array([1]).astype("float32"))
# register a buffer by assigning an attribute with Tensor.
# The `persistable` can only be False by this way.
fc2.buf_name_2 = buffer2

model = paddle.nn.Sequential(fc1, fc2)

# get all named buffers
for name, buffer in model.named_buffers():
    print(name, buffer)

named_children()

Returns an iterator over immediate children layers, yielding both the name of the layer as well as the layer itself.

Yields

(string, Layer) – Tuple containing a name and child layer

Examples

import paddle

linear1 = paddle.nn.Linear(10, 3)
linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
model = paddle.nn.Sequential(linear1, linear2)
for prefix, layer in model.named_children():
    print(prefix, layer)
    # ('0', <paddle.nn.layer.common.Linear object at 0x7fb61ed85830>)
    # ('1', <paddle.nn.layer.common.Linear object at 0x7fb61ed85950>)

named_parameters(prefix='', include_sublayers=True)

Returns an iterator over all parameters in the Layer, yielding tuples of name and parameter.

Parameters
  • prefix (str, optional) – Prefix to prepend to all parameter names. Default: ‘’.

  • include_sublayers (bool, optional) – Whether to include the parameters of sublayers. If True, also include the named parameters from sublayers. Default: True.

Yields

(string, Parameter) – Tuple of name and Parameter

Examples

import paddle

fc1 = paddle.nn.Linear(10, 3)
fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
model = paddle.nn.Sequential(fc1, fc2)
for name, param in model.named_parameters():
    print(name, param)

named_sublayers(prefix='', include_self=False, layers_set=None)

Returns an iterator over all sublayers in the Layer, yielding tuples of name and sublayer. A sublayer that appears more than once is yielded only once.

Parameters
  • prefix (str, optional) – Prefix to prepend to all sublayer names. Default: ‘’.

  • include_self (bool, optional) – Whether to include the Layer itself. Default: False.

  • layers_set (set, optional) – The set used to record duplicate sublayers. Default: None.

Yields

(string, Layer) – Tuple of name and Layer

Examples

import paddle

fc1 = paddle.nn.Linear(10, 3)
fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
model = paddle.nn.Sequential(fc1, fc2)
for prefix, layer in model.named_sublayers():
    print(prefix, layer)
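
The deduplication described above can be sketched with a visited set that is shared across the whole recursive walk. ToyLayer is hypothetical, and this is a simplified model of the documented behavior rather than Paddle's source.

```python
class ToyLayer:
    """Hypothetical layer that only tracks its direct sublayers."""

    def __init__(self, **sublayers):
        self._sub_layers = dict(sublayers)

    def named_sublayers(self, prefix="", include_self=False, layers_set=None):
        if layers_set is None:
            layers_set = set()  # shared across the whole recursive walk
        if include_self and self not in layers_set:
            layers_set.add(self)
            yield prefix, self
        for key, layer in self._sub_layers.items():
            # Build dotted names such as "c.d".
            layer_prefix = prefix + ("." if prefix else "") + key
            yield from layer.named_sublayers(layer_prefix, True, layers_set)


shared = ToyLayer()
model = ToyLayer(a=shared, b=shared, c=ToyLayer(d=ToyLayer()))
print([name for name, _ in model.named_sublayers()])  # ['a', 'c', 'c.d']
```

The same object registered under both "a" and "b" is reported only under the first name it is reached by.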
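
parameters(include_sublayers=True)

Returns a list of all Parameters from the current layer and its sublayers.

Parameters

include_sublayers (bool, optional) – Whether to include the parameters of sublayers. If True, also include the parameters from sublayers. Default: True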

Returns

a list of Parameters.

Return type

list of Tensor

Examples

import paddle

linear = paddle.nn.Linear(1, 1)
print(linear.parameters())  # print linear_0.w_0 and linear_0.b_0

register_buffer(name, tensor, persistable=True)

Registers a tensor as a buffer in the layer.

A buffer is a non-trainable tensor that is not updated by the optimizer but is needed for evaluation and inference, for example, the mean and variance in BatchNorm layers. A registered buffer is persistable by default and is saved into state_dict alongside the parameters. If persistable=False, a non-persistable buffer is registered, which is not part of state_dict.

Buffers can be accessed as attributes using the given names.

Parameters
  • name (string) – name of the buffer. The buffer can be accessed from this layer using the given name

  • tensor (Tensor) – the tensor to be registered as a buffer.

  • persistable (bool) – whether the buffer is part of this layer’s state_dict. Default: True

Returns

None

Examples

import numpy as np
import paddle

linear = paddle.nn.Linear(10, 3)
value = np.array([0]).astype("float32")
buffer = paddle.to_tensor(value)
linear.register_buffer("buf_name", buffer, persistable=True)

# get the buffer by attribute.
print(linear.buf_name)
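
The persistable flag decides whether a buffer shows up in state_dict. The pure-Python sketch below models only that bookkeeping; ToyLayer and its attributes are hypothetical, and plain lists stand in for tensors.

```python
class ToyLayer:
    """Hypothetical model of parameter/buffer bookkeeping in a layer."""

    def __init__(self):
        self._parameters = {}
        self._buffers = {}
        self._non_persistable = set()

    def register_buffer(self, name, tensor, persistable=True):
        self._buffers[name] = tensor
        if persistable:
            self._non_persistable.discard(name)
        else:
            self._non_persistable.add(name)

    def __getattr__(self, name):
        # Only called when normal lookup fails: fall back to the buffers.
        buffers = self.__dict__.get("_buffers", {})
        if name in buffers:
            return buffers[name]
        raise AttributeError(name)

    def state_dict(self):
        # Parameters plus persistable buffers; non-persistable ones are skipped.
        state = dict(self._parameters)
        for name, tensor in self._buffers.items():
            if name not in self._non_persistable:
                state[name] = tensor
        return state


layer = ToyLayer()
layer._parameters["weight"] = [1.0]
layer.register_buffer("running_mean", [0.0])
layer.register_buffer("scratch", [9.0], persistable=False)
print(layer.running_mean)          # [0.0]
print(sorted(layer.state_dict()))  # ['running_mean', 'weight']
```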

register_forward_post_hook(hook)

Registers a forward post-hook for the Layer. The hook is called after the forward function has been computed.

The hook should have the following form, where input and output are the input and output of the Layer, respectively. A forward post-hook can be used to change the output of the Layer or to collect statistics about the Layer.

hook(Layer, input, output) -> None or modified output

Parameters

hook (function) – a function registered as a forward post-hook

Returns

a HookRemoveHelper object that can be used to remove the added hook by calling hook_remove_helper.remove() .

Return type

HookRemoveHelper

Examples

import paddle
import numpy as np

# the forward_post_hook changes the output of the layer: output = output * 2
def forward_post_hook(layer, input, output):
    # layer, input and output can be used for statistics tasks

    # change the output
    return output * 2

linear = paddle.nn.Linear(13, 5)

# register the hook
forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

value1 = np.arange(26).reshape(2, 13).astype("float32")
in1 = paddle.to_tensor(value1)

out0 = linear(in1)

# remove the hook
forward_post_hook_handle.remove()

out1 = linear(in1)

# the hook changed linear's output to output * 2, so out0 equals out1 * 2
assert np.allclose(out0.numpy(), out1.numpy() * 2)
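
How a post-hook's return value is used, and how the returned handle removes the hook, can be sketched in pure Python. ToyLayer and HookHandle are hypothetical stand-ins (HookHandle plays the role of HookRemoveHelper); the sketch assumes hooks run in registration order and that a hook returning None leaves the output unchanged, as the hook form above indicates.

```python
import itertools


class HookHandle:
    """Hypothetical analogue of HookRemoveHelper."""

    _next_id = itertools.count()

    def __init__(self, hooks):
        self._hooks = hooks
        self.id = next(HookHandle._next_id)

    def remove(self):
        self._hooks.pop(self.id, None)


class ToyLayer:
    def __init__(self, fn):
        self._fn = fn
        self._post_hooks = {}  # dicts keep insertion (registration) order

    def register_forward_post_hook(self, hook):
        handle = HookHandle(self._post_hooks)
        self._post_hooks[handle.id] = hook
        return handle

    def __call__(self, *inputs):
        outputs = self._fn(*inputs)
        for hook in list(self._post_hooks.values()):
            result = hook(self, inputs, outputs)
            if result is not None:  # None means "keep the output as is"
                outputs = result
        return outputs


layer = ToyLayer(lambda x: x + 1)
handle = layer.register_forward_post_hook(lambda l, inp, out: out * 2)
layer.register_forward_post_hook(lambda l, inp, out: None)  # observe only
print(layer(3))  # 8
handle.remove()
print(layer(3))  # 4
```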

register_forward_pre_hook(hook)

Registers a forward pre-hook for the Layer. The hook is called before the forward function is computed.

The hook should have the following form, where input is the input of the Layer. The hook can return either a tuple or a single modified value; a single returned value is wrapped into a tuple (unless it is already a tuple). A forward pre-hook can be used to change the input of the Layer or to collect statistics about the Layer.

hook(Layer, input) -> None or modified input

Parameters

hook (function) – a function registered as a forward pre-hook

Returns

a HookRemoveHelper object that can be used to remove the added hook by calling hook_remove_helper.remove() .

Return type

HookRemoveHelper

Examples

import paddle
import numpy as np

# the forward_pre_hook changes the input of the layer: input = input * 2
def forward_pre_hook(layer, input):
    # layer and input can be used for statistics tasks

    # change the input
    input_return = (input[0] * 2)
    return input_return

linear = paddle.nn.Linear(13, 5)

# register the hook
forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

value0 = np.arange(26).reshape(2, 13).astype("float32")
in0 = paddle.to_tensor(value0)
out0 = linear(in0)

# remove the hook
forward_pre_hook_handle.remove()

value1 = value0 * 2
in1 = paddle.to_tensor(value1)
out1 = linear(in1)

# the hook changed linear's input to input * 2, so out0 equals out1
assert np.allclose(out0.numpy(), out1.numpy())
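
The tuple-wrapping rule for pre-hook results can be sketched in pure Python. ToyLayer is a hypothetical stand-in without a removal handle; it models only how a hook's return value replaces the inputs, as described above.

```python
class ToyLayer:
    """Hypothetical layer showing how pre-hook results replace the inputs."""

    def __init__(self, fn):
        self._fn = fn
        self._pre_hooks = []

    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def __call__(self, *inputs):
        for hook in self._pre_hooks:
            result = hook(self, inputs)
            if result is not None:
                # A single value is wrapped into a 1-tuple; tuples are used as is.
                inputs = result if isinstance(result, tuple) else (result,)
        return self._fn(*inputs)


double_first = ToyLayer(lambda a, b=0: a + b)
double_first.register_forward_pre_hook(lambda l, inp: inp[0] * 2)
print(double_first(5, 1))  # 10: the single returned value replaces all inputs

swap = ToyLayer(lambda a, b: a - b)
swap.register_forward_pre_hook(lambda l, inp: (inp[1], inp[0]))
print(swap(5, 1))  # -4
```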

set_state_dict(state_dict, use_structured_name=True)

Sets parameters and persistable buffers from state_dict. All the parameters and buffers are reset with the tensors in state_dict.

Parameters
  • state_dict (dict) – Dict contains all the parameters and persistable buffers.

  • use_structured_name (bool, optional) – If true, use structured name as key, otherwise, use parameter or buffer name as key. Default: True

Returns

None

Examples

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

emb = paddle.nn.Embedding(10, 10)
emb = paddle.DataParallel(emb)

state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy.pdparams")

para_state_dict = paddle.load("paddle_dy.pdparams")
emb.set_state_dict(para_state_dict)
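
The use_structured_name flag only changes which key is looked up for each parameter. In the sketch below, load_state is a hypothetical helper, params is an assumed mapping from structured names (attribute paths such as "weight") to parameter names (such as "embedding_0.w_0"), and missing keys are collected rather than raised, which is an assumption about how mismatches are handled.

```python
def load_state(params, state_dict, use_structured_name=True):
    """Hypothetical sketch of key selection in set_state_dict.

    params: maps each structured name to the parameter's own name.
    Returns (values to load keyed by structured name, missing keys).
    """
    loaded, missing = {}, []
    for structured_name, param_name in params.items():
        key = structured_name if use_structured_name else param_name
        if key in state_dict:
            loaded[structured_name] = state_dict[key]
        else:
            missing.append(key)  # assumption: skip instead of failing
    return loaded, missing


params = {"weight": "embedding_0.w_0"}
print(load_state(params, {"weight": 1.0}))  # ({'weight': 1.0}, [])
print(load_state(params, {"embedding_0.w_0": 2.0}, use_structured_name=False))
# ({'weight': 2.0}, [])
print(load_state(params, {"embedding_0.w_0": 2.0}))  # ({}, ['weight'])
```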

sublayers(include_self=False)

Returns a list of sublayers.

Parameters

include_self (bool, optional) – Whether to include self in the returned list. Default: False

Returns

a list of sublayers.

Return type

list of Layer

Examples

import paddle

class MyLayer(paddle.nn.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self._linear = paddle.nn.Linear(1, 1)
        self._dropout = paddle.nn.Dropout(p=0.5)

    def forward(self, input):
        temp = self._linear(input)
        temp = self._dropout(temp)
        return temp

mylayer = MyLayer()
print(mylayer.sublayers())  # [<paddle.nn.layer.common.Linear object at 0x7f44b58977d0>, <paddle.nn.layer.common.Dropout object at 0x7f44b58978f0>]

to(device=None, dtype=None, blocking=None)

Casts the parameters and buffers of the Layer to the given device and dtype, using the given blocking mode.

Parameters
  • device (str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional) – The device on which to store the Layer. If None, the device is the same as that of the original Tensor. If device is a string, it can be cpu, gpu:x or xpu:x, where x is the index of the GPUs or XPUs. Default: None.

  • dtype (str|numpy.dtype|paddle.dtype|None, optional) – The data type. If None, the dtype is the same as that of the original Tensor. Default: None.

  • blocking (bool|None, optional) – If False and the source is in pinned memory, the copy is asynchronous with respect to the host. Otherwise, the argument has no effect. If None, blocking is set to True. Default: None.

Returns

self

Examples

# required: skip
import paddle

linear=paddle.nn.Linear(2, 2)
linear.weight
#Parameter containing:
#Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
#       [[-0.32770029,  0.38653070],
#        [ 0.46030545,  0.08158520]])

linear.to(dtype='float64')
linear.weight
#Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=False,
#       [[-0.32770029,  0.38653070],
#        [ 0.46030545,  0.08158520]])

linear.to(device='cpu')
linear.weight
#Tensor(shape=[2, 2], dtype=float64, place=CPUPlace, stop_gradient=False,
#       [[-0.32770029,  0.38653070],
#        [ 0.46030545,  0.08158520]])
linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
linear.weight
#Tensor(shape=[2, 2], dtype=float64, place=CUDAPinnedPlace, stop_gradient=False,
#       [[-0.04989364, -0.56889004],
#        [ 0.33960250,  0.96878713]])

to_static_state_dict(destination=None, include_sublayers=True, structured_name_prefix='')

Gets all parameters and buffers of the current layer and its sublayers and returns them in a dict.

Parameters
  • destination (dict, optional) – If provided, all the parameters and persistable buffers are stored in this dict. Default: None

  • include_sublayers (bool, optional) – If True, also include the parameters and persistable buffers from sublayers. Default: True

Returns

A dict containing all the parameters and persistable buffers.

Return type

dict

Examples

import paddle

emb = paddle.nn.Embedding(10, 10)

state_dict = emb.to_static_state_dict()
paddle.save(state_dict, "paddle_dy.pdparams")

train()

Sets this Layer and all its sublayers to training mode. This only affects certain modules like Dropout and BatchNorm.

Returns

None

Examples

import paddle

class MyLayer(paddle.nn.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self._linear = paddle.nn.Linear(1, 1)
        self._dropout = paddle.nn.Dropout(p=0.5)

    def forward(self, input):
        temp = self._linear(input)
        temp = self._dropout(temp)
        return temp

x = paddle.randn([10, 1], 'float32')
mylayer = MyLayer()
mylayer.eval()  # set mylayer._dropout to eval mode
out = mylayer(x)
mylayer.train()  # set mylayer._dropout to train mode
out = mylayer(x)

set_dict(state_dict, use_structured_name=True)

Sets parameters and persistable buffers from state_dict. All the parameters and buffers are reset with the tensors in state_dict.

Parameters
  • state_dict (dict) – Dict contains all the parameters and persistable buffers.

  • use_structured_name (bool, optional) – If true, use structured name as key, otherwise, use parameter or buffer name as key. Default: True

Returns

None

Examples

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

emb = paddle.nn.Embedding(10, 10)
emb = paddle.DataParallel(emb)

state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy.pdparams")

para_state_dict = paddle.load("paddle_dy.pdparams")
emb.set_state_dict(para_state_dict)

load_dict(state_dict, use_structured_name=True)

Sets parameters and persistable buffers from state_dict. All the parameters and buffers are reset with the tensors in state_dict.

Parameters
  • state_dict (dict) – Dict contains all the parameters and persistable buffers.

  • use_structured_name (bool, optional) – If true, use structured name as key, otherwise, use parameter or buffer name as key. Default: True

Returns

None

Examples

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

emb = paddle.nn.Embedding(10, 10)
emb = paddle.DataParallel(emb)

state_dict = emb.state_dict()
paddle.save(state_dict, "paddle_dy.pdparams")

para_state_dict = paddle.load("paddle_dy.pdparams")
emb.set_state_dict(para_state_dict)