switch_case
paddle.static.nn.switch_case(branch_index, branch_fns, default=None, name=None)
Api_attr: Static Graph
This operator is similar to a C++ switch/case statement: it executes the callable in branch_fns whose index equals the value of branch_index.
Parameters
branch_index (Tensor) – A Tensor with exactly one element (shape [] or [1]) that specifies which branch to execute. Its data type must be int32, int64 or uint8.
branch_fns (dict|list|tuple) – If it is a list or tuple, its elements can be (int, callable) pairs, or plain callables, in which case each callable's position in the sequence is used as its index. If it is a dict, its keys are Python integers and its values are callables. All callables must return the same structure of Tensors.
default (callable, optional) – Callable that returns a structure of Tensors. It is executed when no index in branch_fns matches branch_index. Default: None.
name (str, optional) – Normally there is no need for the user to set this property. For more information, please refer to Name. Default: None.
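The three accepted forms of branch_fns all describe the same thing: a table from integer index to callable. As a rough illustration, the pure-Python sketch below reduces each form to such a table. `normalize_branch_fns`, `fn_a` and `fn_b` are hypothetical names introduced here, not part of the Paddle API, and the sketch assumes a list or tuple contains either only plain callables or only (int, callable) pairs.

```python
from typing import Callable, Dict, Sequence, Tuple, Union

# Hypothetical type for the accepted forms of branch_fns (illustration only).
BranchFns = Union[
    Dict[int, Callable],
    Sequence[Union[Callable, Tuple[int, Callable]]],
]


def normalize_branch_fns(branch_fns: BranchFns) -> Dict[int, Callable]:
    """Hypothetical helper: map every accepted form of branch_fns to {index: callable}.

    Validation (duplicate indices, non-callables, ...) is deliberately left out.
    """
    if isinstance(branch_fns, dict):
        # dict form: keys are the indices, values are the callables
        return dict(branch_fns)
    if all(callable(fn) for fn in branch_fns):
        # plain callables: each callable's position in the sequence is its index
        return dict(enumerate(branch_fns))
    # (int, callable) pairs: the first element of each pair is the index
    return {index: fn for index, fn in branch_fns}


def fn_a():
    return "a"


def fn_b():
    return "b"


# All three forms describe the same table {0: fn_a, 1: fn_b}.
as_dict = normalize_branch_fns({0: fn_a, 1: fn_b})
as_pairs = normalize_branch_fns([(0, fn_a), (1, fn_b)])
as_callables = normalize_branch_fns([fn_a, fn_b])
print(as_dict == as_pairs == as_callables)  # True
```

Under these assumptions, [fn_a, fn_b] and [(0, fn_a), (1, fn_b)] are interchangeable; the pair form is what lets indices skip values, as in the [(0, fn_1), (4, fn_2), (7, fn_3)] list in the Examples section.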
Returns
The Tensors returned by the callable in branch_fns whose index matches branch_index. If no index matches, the Tensors returned by default when default is not None; otherwise, the Tensors returned by the callable with the largest index in branch_fns.
Return type
Tensor|list(Tensor)
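The selection rule described under Returns can be modeled in a few lines of plain Python. This is a hedged sketch of the documented behavior, not Paddle's implementation: `select_branch` is a hypothetical helper, the branch index is an ordinary int rather than a Tensor, and the branches are called eagerly instead of being built into a static graph.

```python
from typing import Callable, Dict, Optional


def select_branch(
    branch_index: int,
    table: Dict[int, Callable],
    default: Optional[Callable] = None,
):
    """Hypothetical model of the documented selection rule (not Paddle code)."""
    if branch_index in table:
        # an index in branch_fns matches branch_index
        return table[branch_index]()
    if default is not None:
        # no match, so the explicit default branch runs
        return default()
    # no match and no default: fall back to the callable with the largest index
    return table[max(table)]()


table = {0: lambda: "fn_1", 4: lambda: "fn_2", 7: lambda: "fn_3"}
print(select_branch(4, table))                     # fn_2
print(select_branch(2, table, lambda: "default"))  # default
print(select_branch(2, table))                     # fn_3 (largest index is 7)
```

The last call mirrors the out_3 case in the Examples section: index 2 matches none of 0, 4 and 7, and with no default the branch with index 7 runs.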
Raises
TypeError – If the type of branch_index is not Tensor.
TypeError – If the data type of branch_index is not int32, int64 or uint8.
TypeError – If the type of branch_fns is not dict, list or tuple.
TypeError – If the elements of branch_fns are not 2-tuples.
TypeError – If the first element of a 2-tuple in branch_fns is not an integer.
ValueError – If the first elements of the 2-tuples in branch_fns are not unique.
TypeError – If the second element of a 2-tuple in branch_fns is not callable.
TypeError – If default is not None but is not callable.
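The checks listed under Raises on branch_fns and default can be restated as a plain-Python validator, which makes it easier to see which inputs are rejected and with which exception type. `check_branch_args` and `branch` are hypothetical, the messages are illustrative rather than Paddle's actual error text, and the two branch_index checks are left out because they need a real Tensor.

```python
def check_branch_args(branch_fns, default=None):
    """Hypothetical restatement of the documented branch_fns/default checks.

    Returns the validated entries as a list of (index, callable) pairs.
    """
    if not isinstance(branch_fns, (dict, list, tuple)):
        raise TypeError("branch_fns must be a dict, list or tuple")
    # A dict contributes its (key, value) items; a list or tuple is used as-is.
    pairs = list(branch_fns.items()) if isinstance(branch_fns, dict) else list(branch_fns)
    # A sequence made only of plain callables is indexed by position.
    if all(callable(entry) for entry in pairs):
        pairs = list(enumerate(pairs))
    seen = set()
    for pair in pairs:
        if not isinstance(pair, tuple) or len(pair) != 2:
            raise TypeError("each element of branch_fns must be a 2-tuple")
        index, fn = pair
        if not isinstance(index, int):
            raise TypeError("the first element of each 2-tuple must be an integer")
        if index in seen:
            raise ValueError(f"branch index {index} is not unique")
        seen.add(index)
        if not callable(fn):
            raise TypeError("the second element of each 2-tuple must be callable")
    if default is not None and not callable(default):
        raise TypeError("default must be callable when it is not None")
    return pairs


def branch():
    return 0


# Each of these inputs violates one documented rule.
for bad_fns in ([(1, branch), (1, branch)], [(1.5, branch)], [(0, "not callable")]):
    try:
        check_branch_args(bad_fns)
    except (TypeError, ValueError) as err:
        print(f"{type(err).__name__}: {err}")
```

Duplicate indices are the only case reported as ValueError; every other malformed input is a TypeError, matching the list above.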
Examples
```python
>>> import paddle
>>> paddle.enable_static()
>>> def fn_1():
...     return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)
>>> def fn_2():
...     return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)
>>> def fn_3():
...     return paddle.full(shape=[3], dtype='int32', fill_value=3)
>>> startup_program = paddle.static.default_startup_program()
>>> main_program = paddle.static.default_main_program()
>>> with paddle.static.program_guard(main_program, startup_program):
...     index_1 = paddle.full(shape=[1], dtype='int32', fill_value=1)
...     index_2 = paddle.full(shape=[1], dtype='int32', fill_value=2)
...
...     # index_1 is 1, so fn_1 is called.
...     out_1 = paddle.static.nn.switch_case(
...         branch_index=index_1,
...         branch_fns={1: fn_1, 2: fn_2},
...         default=fn_3)
...
...     # index_2 is 2, so fn_2 is called.
...     out_2 = paddle.static.nn.switch_case(
...         branch_index=index_2,
...         branch_fns=[(1, fn_1), (2, fn_2)],
...         default=fn_3)
...
...     # Argument default is None and no index matches. fn_3 will be called because of the max index 7.
...     out_3 = paddle.static.nn.switch_case(
...         branch_index=index_2,
...         branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
...
...     exe = paddle.static.Executor(paddle.CPUPlace())
...     res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
>>> print(res_1)
[[1. 1.]]
>>> print(res_2)
[[2 2]
 [2 2]]
>>> print(res_3)
[3 3 3]
```