Executor

class paddle.static.Executor(place=None)
Api_attr

Static Graph

An Executor in Python that supports running on a single GPU or multiple GPUs, and on a single CPU or multiple CPUs.

Parameters

place (paddle.CPUPlace()|paddle.CUDAPlace(n)|str|None) – The device on which the executor runs. When this parameter is None, PaddlePaddle sets the default device according to its installation version: CPUPlace() for the CPU version and CUDAPlace(0) for the GPU version. If place is a string, it can be "cpu" or "gpu:x", where x is the index of the GPU. The default is None.
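As an illustration of the accepted string forms, the sketch below maps a place string to a (device type, index) pair. parse_place_string is a hypothetical helper written for this page, not a Paddle API; it only recognizes the two forms documented here ("cpu" and "gpu:x").

```python
import re
from typing import Optional, Tuple

# Hypothetical helper, NOT a Paddle API: recognizes only the two
# documented string forms, "cpu" and "gpu:x".
_GPU_PATTERN = re.compile(r"gpu:(\d+)")


def parse_place_string(place: str) -> Tuple[str, Optional[int]]:
    """Return (device_type, device_index) for a documented place string."""
    if place == "cpu":
        return ("cpu", None)
    match = _GPU_PATTERN.fullmatch(place)
    if match is not None:
        # x is the index of the GPU, e.g. "gpu:1" -> 1.
        return ("gpu", int(match.group(1)))
    raise ValueError(f"unsupported place string: {place!r}")


print(parse_place_string("cpu"))    # ('cpu', None)
print(parse_place_string("gpu:1"))  # ('gpu', 1)
```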

Returns

Executor

Examples

import paddle
import numpy
import os

# Executor is only used in static graph mode
paddle.enable_static()

# Set place explicitly.
# use_cuda = True
# place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
# exe = paddle.static.Executor(place)

# If you don't set place, PaddlePaddle sets the default device.
exe = paddle.static.Executor()

train_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
    data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
    hidden = paddle.static.nn.fc(data, 10)
    loss = paddle.mean(hidden)
    paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)

# Run the startup program once and only once.
# There is no need to optimize/compile the startup program.
exe.run(startup_program)

# Run the main program directly without compiling it.
x = numpy.random.random(size=(10, 1)).astype('float32')
loss_data, = exe.run(train_program, feed={"X": x}, fetch_list=[loss.name])

# Or, compile the program and run it. See `CompiledProgram`
# for more details.
# NOTE: If you run the program on CPU or Paddle is the
# CPU version, you need to specify CPU_NUM; otherwise,
# PaddlePaddle uses the number of logical cores as
# CPU_NUM. In that case, the batch size of the input
# must be at least CPU_NUM, or the process will fail
# with an exception.

# Set place explicitly.
# if not use_cuda:
#     os.environ['CPU_NUM'] = str(2)

# If you don't set place and PaddlePaddle is the CPU version:
os.environ['CPU_NUM'] = str(2)

compiled_prog = paddle.static.CompiledProgram(
    train_program).with_data_parallel(loss_name=loss.name)
loss_data, = exe.run(compiled_prog, feed={"X": x}, fetch_list=[loss.name])

close()

Close the executor. This interface is used for distributed training (PServers mode). The executor cannot be used after calling this interface, because it releases the resources associated with the current Trainer.

Returns

None

Examples

import paddle

cpu = paddle.CPUPlace()
exe = paddle.static.Executor(cpu)
# execute training or testing
exe.close()

run(program=None, feed=None, fetch_list=None, feed_var_name='feed', fetch_var_name='fetch', scope=None, return_numpy=True, use_program_cache=False, return_merged=True, use_prune=False)

Run the specified Program or CompiledProgram. Note that, by default, the executor executes all operators in the Program or CompiledProgram and does not prune operators according to fetch_list (see use_prune). You can specify the scope in which Tensors are stored while the executor runs; if the scope is not set, the executor uses the global scope, i.e. paddle.static.global_scope().

Parameters
  • program (Program|CompiledProgram) – This parameter represents the Program or CompiledProgram to be executed. If this parameter is not provided (i.e., it is None), the program is set to paddle.static.default_main_program(). The default is None.

  • feed (list|dict) – This parameter represents the input Tensors of the model. For single-card training, feed must be a dict; for multi-card training, feed can be a dict or a list of dicts, one per device. If feed is a dict, its data is split evenly across the devices (CPU/GPU), so make sure the number of samples in the current mini-batch is not less than the number of places. If feed is a list, each element is copied directly to the corresponding device, so the length of the list should equal the number of places. The default is None.

  • fetch_list (list) – This parameter represents the Tensors that need to be returned after the model runs. The default is None.

  • feed_var_name (str) – This parameter represents the name of the input Tensor of the feed operator. The default is “feed”.

  • fetch_var_name (str) – This parameter represents the name of the output Tensor of the fetch operator. The default is “fetch”.

  • scope (Scope) – The scope used to run this program; you can switch it to a different scope. The default is paddle.static.global_scope().

  • return_numpy (bool) – This parameter indicates whether to convert the fetched Tensors (the Tensors specified in fetch_list) to numpy.ndarray. If it is False, the return value is a list of LoDTensor. The default is True.

  • use_program_cache (bool) – This parameter indicates whether the input Program is cached. If it is True, the model may run faster when the input program is a paddle.static.Program and the parameters of this interface (program, feed Tensor names, and fetch_list Tensors) remain unchanged between runs. The default is False.

  • return_merged (bool) – This parameter indicates whether the fetched Tensors (the Tensors specified in fetch_list) should be merged along the execution-device dimension. If return_merged is False, the return value is a two-dimensional list of Tensor / LoDTensorArray (return_numpy is False) or a two-dimensional list of numpy.ndarray (return_numpy is True). If return_merged is True, the return value is a one-dimensional list of Tensor / LoDTensorArray (return_numpy is False) or a one-dimensional list of numpy.ndarray (return_numpy is True). See Examples 2 for details. If the lengths of the fetched results vary, set return_merged to False so that the fetched results are not merged. The default is True for compatibility; False may become the default in a future version.

  • use_prune (bool) – This parameter indicates whether the input Program will be pruned. If it is True, the program is pruned according to the given feed and fetch_list: operators and variables that generate the feed variables, or that are not needed to compute fetch_list, are pruned. The default is False, which means the program is not pruned and all operators and variables are executed. Note that if the tuple returned from Optimizer.minimize() is passed to fetch_list, use_prune is overridden to True and the program is pruned.
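The pruning rule for use_prune can be sketched on a toy, straight-line program: starting from the fetch targets, walk the operators backwards, keep every operator that produces a still-needed variable, and skip producers of variables that are supplied through feed. Op and prune are hypothetical names used only for illustration; this is not Paddle's pruning implementation.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass(frozen=True)
class Op:
    """A toy operator: a name plus the variables it reads and writes."""
    name: str
    inputs: tuple
    outputs: tuple


def prune(ops: List[Op], feed: Set[str], fetch: Set[str]) -> List[Op]:
    """Keep only the ops needed to compute `fetch`, treating `feed` vars as given."""
    needed = set(fetch)
    kept = []
    # Scan the ops in reverse program order (the toy program is topologically sorted).
    for op in reversed(ops):
        produced = set(op.outputs) & needed
        # Fed variables are supplied by the caller, so their producers are dropped.
        produced -= feed
        if produced:
            kept.append(op)
            needed.update(op.inputs)
    kept.reverse()
    return kept


program = [
    Op("read_x", (), ("x",)),                  # produces the feed variable x
    Op("fc", ("x", "w"), ("h",)),
    Op("mean", ("h",), ("loss",)),
    Op("accuracy", ("h", "label"), ("acc",)),  # not needed to fetch loss
]
print([op.name for op in prune(program, feed={"x"}, fetch={"loss"})])
# ['fc', 'mean']
```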

Returns

The fetched result list.

Return type

List

Notes

  1. If it is multi-card running and the feed parameter is a dict, the input data is split evenly across the cards. For example, when two GPUs run the model and the input has 3 samples, [0, 1, 2], GPU0 receives 1 sample, [0], and GPU1 receives 2 samples, [1, 2]. If the number of samples is less than the number of devices, the program throws an exception. Therefore, when running the model, make sure the last batch of the dataset contains at least as many samples as there are CPU cores or GPU cards; if it contains fewer, it is recommended to discard that batch.

  2. If more than one CPU core or GPU card is available, the fetched values of the same Tensor (a Tensor in fetch_list) from different devices are concatenated along dimension 0 (when return_merged is True).
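The two notes above can be sketched in plain Python: split_evenly distributes a dict-fed mini-batch across devices, and merge_along_dim0 concatenates the per-device results the way return_merged=True is described. Both helpers are hypothetical; the split reproduces the documented 3-sample, 2-GPU example, but Paddle may distribute remainders differently for other batch sizes.

```python
from typing import List, Sequence


def split_evenly(samples: Sequence, num_devices: int) -> List[list]:
    """Split samples across devices; the last devices absorb the remainder."""
    if len(samples) < num_devices:
        # Mirrors the documented behavior: fewer samples than devices is an error.
        raise ValueError("fewer samples than devices")
    base, rem = divmod(len(samples), num_devices)
    shards, start = [], 0
    for device in range(num_devices):
        # The last `rem` devices get one extra sample, matching the
        # documented example: 3 samples on 2 GPUs -> [0] and [1, 2].
        size = base + (1 if device >= num_devices - rem else 0)
        shards.append(list(samples[start:start + size]))
        start += size
    return shards


def merge_along_dim0(per_device: List[list]) -> list:
    """Concatenate per-device results along dimension 0."""
    merged = []
    for shard in per_device:
        merged.extend(shard)
    return merged


shards = split_evenly([0, 1, 2], num_devices=2)
print(shards)                    # [[0], [1, 2]]
print(merge_along_dim0(shards))  # [0, 1, 2]
```

With return_merged=False you would instead receive the per-device list (shards here), which is why Examples 2 reports a shape of (2, 3, class_dim) for a batch of 6 on two GPUs.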

Examples 1:
import paddle
import numpy

# First create the Executor.
paddle.enable_static()
place = paddle.CPUPlace()  # paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)

data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
hidden = paddle.static.nn.fc(data, 10)
loss = paddle.mean(hidden)
adam = paddle.optimizer.Adam()
adam.minimize(loss)
i = paddle.zeros(shape=[1], dtype='int64')
array = paddle.fluid.layers.array_write(x=loss, i=i)

# Run the startup program once and only once.
exe.run(paddle.static.default_startup_program())

x = numpy.random.random(size=(10, 1)).astype('float32')
loss_val, array_val = exe.run(feed={'X': x},
                              fetch_list=[loss.name, array.name])
print(array_val)
# [array([0.02153828], dtype=float32)]
Examples 2:
import paddle
import numpy as np

# First create the Executor.
paddle.enable_static()
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)

data = paddle.static.data(name='X', shape=[None, 1], dtype='float32')
class_dim = 2
prediction = paddle.static.nn.fc(data, class_dim)
loss = paddle.mean(prediction)
adam = paddle.optimizer.Adam()
adam.minimize(loss)

# Run the startup program once and only once.
exe.run(paddle.static.default_startup_program())
build_strategy = paddle.static.BuildStrategy()
binary = paddle.static.CompiledProgram(
    paddle.static.default_main_program()).with_data_parallel(
        loss_name=loss.name, build_strategy=build_strategy)
batch_size = 6
x = np.random.random(size=(batch_size, 1)).astype('float32')

# Set return_merged as False to fetch unmerged results:
unmerged_prediction, = exe.run(binary,
                               feed={'X': x},
                               fetch_list=[prediction.name],
                               return_merged=False)
# If the user uses two GPU cards to run this python code, the printed result will be
# (2, 3, class_dim). The first dimension value of the printed result is the number of used
# GPU cards, and the second dimension value is the quotient of batch_size and the
# number of used GPU cards.
print("The unmerged prediction shape: {}".format(
    np.array(unmerged_prediction).shape))
print(unmerged_prediction)

# Set return_merged as True to fetch merged results:
merged_prediction, = exe.run(binary,
                             feed={'X': x},
                             fetch_list=[prediction.name],
                             return_merged=True)
# If the user uses two GPU cards to run this python code, the printed result will be
# (6, class_dim). The first dimension value of the printed result is the batch_size.
print("The merged prediction shape: {}".format(
    np.array(merged_prediction).shape))
print(merged_prediction)

# Out:
# The unmerged prediction shape: (2, 3, 2)
# [array([[-0.37620035, -0.19752218],
#        [-0.3561043 , -0.18697084],
#        [-0.24129935, -0.12669306]], dtype=float32), array([[-0.24489994, -0.12858354],
#        [-0.49041364, -0.25748932],
#        [-0.44331917, -0.23276259]], dtype=float32)]
# The merged prediction shape: (6, 2)
# [[-0.37789783 -0.19921964]
#  [-0.3577645  -0.18863106]
#  [-0.24274671 -0.12814042]
#  [-0.24635398 -0.13003758]
#  [-0.49232286 -0.25939852]
#  [-0.44514108 -0.2345845 ]]

infer_from_dataset(program=None, dataset=None, scope=None, thread=0, debug=False, fetch_list=None, fetch_info=None, print_period=100, fetch_handler=None)

Infer from a pre-defined Dataset. Dataset is defined in paddle.fluid.dataset. Given a program (either a Program or a CompiledProgram), infer_from_dataset consumes all data samples in the dataset. The scope can be given by the user; by default, it is global_scope(). The number of threads is controlled by thread: the number of threads actually used is the minimum of the thread number set in the Dataset and the value of thread passed to this interface. Debug mode can be enabled so that the executor displays the run time of all operators and the throughput of the current inference task.

The documentation of infer_from_dataset is almost the same as that of train_from_dataset, except that in distributed training, pushing gradients is disabled in infer_from_dataset. infer_from_dataset() makes it easy to run multi-threaded evaluation.
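The thread-count rule described above can be written as a small function. effective_thread_num is a hypothetical helper that encodes the rule exactly as this page states it (0 means "use the Dataset's thread number"; otherwise the smaller value wins); it is not taken from Paddle's source, and the actual behavior may differ.

```python
def effective_thread_num(dataset_thread_num: int, thread: int = 0) -> int:
    """Thread count per the rule stated on this page (hypothetical helper)."""
    if thread <= 0:
        # thread=0 (the default) means: use the Dataset's thread number.
        return dataset_thread_num
    # Otherwise, the smaller of the two values is used.
    return min(dataset_thread_num, thread)


print(effective_thread_num(4))     # 4
print(effective_thread_num(4, 2))  # 2
print(effective_thread_num(4, 8))  # 4
```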

Parameters
  • program (Program|CompiledProgram) – The program to run. If not provided, default_main_program (not compiled) is used.

  • dataset (paddle.fluid.Dataset) – A dataset created outside this function. Provide a well-defined dataset before calling this function; see the Dataset documentation if needed. The default is None.

  • scope (Scope) – The scope used to run this program; you can switch to a different scope for each run. The default is global_scope.

  • thread (int) – The number of threads to run in this function. The default is 0, which means the thread number of the dataset is used.

  • debug (bool) – Whether to enable debug mode, in which the executor displays the run time of all operators and the throughput of the current task. The default is False.

  • fetch_list (Tensor List) – The Tensors to fetch; each Tensor is printed during the run. The default is None.

  • fetch_info (String List) – The print information for each Tensor. The default is None.

  • print_period (int) – The number of mini-batches between prints. The default is 100.

  • fetch_handler (FetchHandler) – A user-defined class for handling fetched outputs.

Returns

None

Examples

import paddle

paddle.enable_static()
place = paddle.CPUPlace()  # you can set place = paddle.CUDAPlace(0) to use gpu
exe = paddle.static.Executor(place)
x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64")
y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1)
dataset = paddle.fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y])
dataset.set_thread(1)
# you should set your own filelist, e.g. filelist = ["dataA.txt"]
filelist = []
dataset.set_filelist(filelist)
exe.run(paddle.static.default_startup_program())
exe.infer_from_dataset(program=paddle.static.default_main_program(),
                       dataset=dataset)

train_from_dataset(program=None, dataset=None, scope=None, thread=0, debug=False, fetch_list=None, fetch_info=None, print_period=100, fetch_handler=None)

Train from a pre-defined Dataset. Dataset is defined in paddle.fluid.dataset. Given a program (either a Program or a CompiledProgram), train_from_dataset consumes all data samples in the dataset. The scope can be given by the user; by default, it is global_scope(). The number of threads is controlled by thread: the number of threads actually used is the minimum of the thread number set in the Dataset and the value of thread passed to this interface. Debug mode can be enabled so that the executor displays the run time of all operators and the throughput of the current training task.

Note: train_from_dataset will destroy all resources created within executor for each run.

Parameters
  • program (Program|CompiledProgram) – The program to run. If not provided, default_main_program (not compiled) is used.

  • dataset (paddle.fluid.Dataset) – A dataset created outside this function. Provide a well-defined dataset before calling this function; see the Dataset documentation if needed. The default is None.

  • scope (Scope) – The scope used to run this program; you can switch to a different scope for each run. The default is global_scope.

  • thread (int) – The number of threads to run in this function. The default is 0, which means the thread number of the dataset is used.

  • debug (bool) – Whether to enable debug mode, in which the executor displays the run time of all operators and the throughput of the current task. The default is False.

  • fetch_list (Tensor List) – The Tensors to fetch; each Tensor is printed during training. The default is None.

  • fetch_info (String List) – The print information for each Tensor; its length should equal the length of fetch_list. The default is None.

  • print_period (int) – The number of mini-batches between prints. The default is 100.

  • fetch_handler (FetchHandler) – A user-defined class for handling fetched outputs.

Returns

None

Examples

import paddle

paddle.enable_static()
place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu
exe = paddle.static.Executor(place)
x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64")
y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1)
dataset = paddle.fluid.DatasetFactory().create_dataset()
dataset.set_use_var([x, y])
dataset.set_thread(1)
# you should set your own filelist, e.g. filelist = ["dataA.txt"]
filelist = []
dataset.set_filelist(filelist)
exe.run(paddle.static.default_startup_program())
exe.train_from_dataset(program=paddle.static.default_main_program(),
                       dataset=dataset)
