switch_case
- paddle.static.nn.switch_case(branch_index, branch_fns, default=None, name=None)
-
- Api_attr
-
Static Graph
This operator works like a C++ switch/case statement: it runs the callable in branch_fns whose index matches branch_index.
- Parameters
-
branch_index (Tensor) – A Tensor with exactly one element (shape [] or shape [1]) that specifies which branch to execute. The data type is int32, int64 or uint8.
branch_fns (dict|list|tuple) – If it is a list or tuple, its elements can be (int, callable) pairs, or plain callables, in which case each callable's position in the sequence is used as its index. If it is a dict, each key is a Python integer and each value is a callable. All callables must return the same structure of Tensors.
default (callable, optional) – Callable that returns a structure of Tensors. The default value is None.
name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
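The three accepted forms of branch_fns all describe the same thing: a mapping from integer index to callable. The following is a minimal plain-Python sketch of that equivalence. The helper normalize_branch_fns is hypothetical and written only for illustration; it is not part of Paddle and does not claim to mirror Paddle's internal handling (for example, it assumes a list holds either only callables or only pairs).

```python
def normalize_branch_fns(branch_fns):
    """Return a list of (index, callable) pairs for any accepted form.

    Hypothetical helper for illustration only; not a Paddle API.
    """
    if isinstance(branch_fns, dict):
        # Dict form: keys are the indices, values are the callables.
        return list(branch_fns.items())
    if all(callable(fn) for fn in branch_fns):
        # Plain callables: the position in the sequence becomes the index.
        return list(enumerate(branch_fns))
    # Otherwise assume the elements are already (int, callable) pairs.
    return [tuple(pair) for pair in branch_fns]


def fn_a():
    return "a"


def fn_b():
    return "b"


as_dict = normalize_branch_fns({1: fn_a, 2: fn_b})
as_pairs = normalize_branch_fns([(1, fn_a), (2, fn_b)])
as_callables = normalize_branch_fns((fn_a, fn_b))

print(as_dict == as_pairs)             # dict and pair forms agree
print([i for i, _ in as_callables])    # positions used as indices
```

Note that with plain callables the indices start at 0, so a list of two functions answers to branch_index values 0 and 1.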
- Returns
-
Tensors returned by the callable specified by branch_index in branch_fns; or Tensors returned by default if default is not None and no index in branch_fns matches; or Tensors returned by the callable with the largest index in branch_fns if default is None and no index matches.
- Return type
-
Tensor|list(Tensor)
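The selection rule described under Returns has three cases, and the fallback to the largest index is easy to overlook. The sketch below models the rule in plain Python with an ordinary int in place of the branch_index Tensor. The function select_branch is a hypothetical name used for illustration; in Paddle the choice is made when the static graph runs, not by a Python lookup like this one.

```python
def select_branch(index, branch_fns, default=None):
    """Return the callable that the documented rule would pick.

    Hypothetical model for illustration only; not a Paddle API.
    `branch_fns` is a dict or a sequence of (int, callable) pairs.
    """
    table = dict(branch_fns)
    if index in table:
        # Case 1: an index matches.
        return table[index]
    if default is not None:
        # Case 2: nothing matches, but a default was given.
        return default
    # Case 3: nothing matches and there is no default,
    # so the callable with the largest index is used.
    return table[max(table)]


def fn_1():
    return "fn_1"


def fn_2():
    return "fn_2"


def fn_3():
    return "fn_3"


matched = select_branch(2, {1: fn_1, 2: fn_2}, default=fn_3)()
fell_back_to_default = select_branch(5, [(1, fn_1), (2, fn_2)], default=fn_3)()
fell_back_to_max = select_branch(2, [(0, fn_1), (4, fn_2), (7, fn_3)])()

print(matched, fell_back_to_default, fell_back_to_max)
```

Passing an explicit default is the safer choice when an unmatched index should not silently run the last branch.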
- Raises
-
TypeError – If the type of branch_index is not Tensor.
TypeError – If the data type of branch_index is not int32, int64 or uint8.
TypeError – If the type of branch_fns is not dict, list or tuple.
TypeError – If an element of branch_fns is not a 2-tuple.
TypeError – If the first element of a 2-tuple in branch_fns is not an integer.
ValueError – If the first element of a 2-tuple in branch_fns is not unique.
TypeError – If the second element of a 2-tuple in branch_fns is not callable.
TypeError – If default is not None but is not callable.
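The checks on branch_fns and default listed above can be read as a simple validation pass. The following plain-Python sketch reproduces those conditions so the error cases are easy to try without building a graph. Both check_branch_fns and error_name are hypothetical helpers, the error messages are invented, and the sketch assumes a list or tuple is already in (int, callable) pair form; the branch_index checks are omitted because they need a Tensor.

```python
def check_branch_fns(branch_fns, default=None):
    """Raise the documented error types for invalid arguments.

    Hypothetical model for illustration only; not a Paddle API.
    """
    if not isinstance(branch_fns, (dict, list, tuple)):
        raise TypeError("branch_fns must be a dict, list or tuple")
    if isinstance(branch_fns, dict):
        pairs = list(branch_fns.items())
    else:
        pairs = list(branch_fns)

    seen = set()
    for pair in pairs:
        if not isinstance(pair, tuple) or len(pair) != 2:
            raise TypeError("each element of branch_fns must be a 2-tuple")
        key, fn = pair
        if not isinstance(key, int):
            raise TypeError("the first element of each pair must be an integer")
        if key in seen:
            raise ValueError("the first element of each pair must be unique")
        seen.add(key)
        if not callable(fn):
            raise TypeError("the second element of each pair must be callable")

    if default is not None and not callable(default):
        raise TypeError("default must be callable or None")


def error_name(branch_fns, default=None):
    """Return the name of the exception raised, or None if the input is valid."""
    try:
        check_branch_fns(branch_fns, default)
    except (TypeError, ValueError) as exc:
        return type(exc).__name__
    return None


def fn():
    return None


valid = error_name([(1, fn), (2, fn)], default=fn)
wrong_container = error_name("not a container")
duplicate_index = error_name([(1, fn), (1, fn)])
bad_default = error_name({1: fn}, default=3)

print(valid, wrong_container, duplicate_index, bad_default)
```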
Examples
import paddle

paddle.enable_static()

def fn_1():
    return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

def fn_2():
    return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

def fn_3():
    return paddle.full(shape=[3], dtype='int32', fill_value=3)

main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with paddle.static.program_guard(main_program, startup_program):
    index_1 = paddle.full(shape=[1], dtype='int32', fill_value=1)
    index_2 = paddle.full(shape=[1], dtype='int32', fill_value=2)

    # Dict form: index 1 matches, so fn_1 is called.
    out_1 = paddle.static.nn.switch_case(
        branch_index=index_1,
        branch_fns={1: fn_1, 2: fn_2},
        default=fn_3)

    # List-of-pairs form: index 2 matches, so fn_2 is called.
    out_2 = paddle.static.nn.switch_case(
        branch_index=index_2,
        branch_fns=[(1, fn_1), (2, fn_2)],
        default=fn_3)

    # Argument default is None and no index matches.
    # fn_3 will be called because it has the max index, 7.
    out_3 = paddle.static.nn.switch_case(
        branch_index=index_2,
        branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])

    exe = paddle.static.Executor(paddle.CPUPlace())
    res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
    print(res_1)  # [[1. 1.]]
    print(res_2)  # [[2 2] [2 2]]
    print(res_3)  # [3 3 3]