case
paddle.static.nn.case(pred_fn_pairs, default=None, name=None)

Api_attr: Static Graph
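This operator works like an if-elif-elif-else chain: the predicates in pred_fn_pairs are checked in order, and the output comes from the callable paired with the first predicate that is True.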
Parameters
- pred_fn_pairs (list|tuple) – A list or tuple of (pred, fn) pairs. pred is a boolean Tensor whose numel should be 1 (shape [] or shape [1]), and fn is a callable. All callables must return the same structure of Tensors.
- default (callable, optional) – Callable that returns a structure of Tensors.
- name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
Returns
The Tensors returned by the callable from the first pair whose pred is True. If no pred in pred_fn_pairs is True, the Tensors returned by default when default is not None, or the Tensors returned by the last callable in pred_fn_pairs when default is None.

Return type
Tensor|list(Tensor)
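The selection rule above can be modeled in plain Python as a quick sanity check. This is a minimal sketch, not Paddle's implementation: it assumes plain Python `bool` values in place of boolean Tensors, and the helper `case_reference` is hypothetical. The real operator records conditional branches in the static graph instead of running a Python loop.

```python
from typing import Callable, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def case_reference(
    pred_fn_pairs: Sequence[Tuple[bool, Callable[[], T]]],
    default: Optional[Callable[[], T]] = None,
) -> T:
    """Hypothetical pure-Python model of the documented case() semantics.

    Plain bools stand in for the 1-element boolean Tensors used by Paddle.
    """
    # Walk the pairs in order, like an if-elif-elif chain.
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()  # the first True predicate wins
    # No predicate was True: use default when it is given.
    if default is not None:
        return default()
    # default is None: fall back to the last callable in pred_fn_pairs.
    return pred_fn_pairs[-1][1]()


# Mirrors the Examples section: pred_1 is True, pred_2 and pred_3 are False.
pred_1, pred_2, pred_3 = 0.2 < 0.3, 0.3 < 0.1, 0.3 == 0.1
fn_1 = lambda: "fn_1"
fn_2 = lambda: "fn_2"
fn_3 = lambda: "fn_3"

print(case_reference([(pred_1, fn_1), (pred_2, fn_2)], default=fn_3))  # fn_1
print(case_reference([(pred_2, fn_2), (pred_3, fn_3)]))                # fn_3
```

Note that default, when given, takes priority over the "last callable" fallback; the fallback only applies when default is None.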
Raises
- TypeError – If the type of pred_fn_pairs is not list or tuple.
- TypeError – If the type of elements in pred_fn_pairs is not tuple.
- TypeError – If the size of tuples in pred_fn_pairs is not 2.
- TypeError – If the first element of a 2-tuple in pred_fn_pairs is not a Tensor.
- TypeError – If the second element of a 2-tuple in pred_fn_pairs is not callable.
- TypeError – If default is not None but it is not callable.
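The documented TypeError conditions can be restated as a standalone check. This sketch is hypothetical: `check_case_args` is not a Paddle function, the `is_tensor` parameter is an assumed stand-in for Paddle's own Tensor type check (it accepts plain `bool` by default so the sketch runs without Paddle), and the order of the checks and the error messages are illustrative rather than Paddle's.

```python
from typing import Any, Callable, Optional


def check_case_args(
    pred_fn_pairs: Any,
    default: Optional[Any] = None,
    is_tensor: Callable[[Any], bool] = lambda obj: isinstance(obj, bool),
) -> None:
    """Hypothetical re-statement of the documented TypeError checks.

    `is_tensor` is an assumed stand-in for a real Tensor type check.
    """
    if not isinstance(pred_fn_pairs, (list, tuple)):
        raise TypeError("pred_fn_pairs must be a list or tuple")
    for pair in pred_fn_pairs:
        if not isinstance(pair, tuple):
            raise TypeError("each element of pred_fn_pairs must be a tuple")
        if len(pair) != 2:
            raise TypeError("each tuple in pred_fn_pairs must have size 2")
        pred, fn = pair
        if not is_tensor(pred):
            raise TypeError("the first element of each pair must be a Tensor")
        if not callable(fn):
            raise TypeError("the second element of each pair must be callable")
    if default is not None and not callable(default):
        raise TypeError("default must be None or callable")


# Usage: a non-callable default is rejected.
try:
    check_case_args([(True, lambda: 1)], default=3)
except TypeError as err:
    print(err)  # default must be None or callable
```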
Examples
import paddle

paddle.enable_static()

def fn_1():
    return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

def fn_2():
    return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

def fn_3():
    return paddle.full(shape=[3], dtype='int32', fill_value=3)

# Build the ops into the default main program (startup program holds initializers).
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()

with paddle.static.program_guard(main_program, startup_program):
    x = paddle.full(shape=[1], dtype='float32', fill_value=0.3)
    y = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
    z = paddle.full(shape=[1], dtype='float32', fill_value=0.2)

    pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
    pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
    pred_3 = paddle.equal(x, y)      # false: 0.3 == 0.1

    # Call fn_1 because pred_1 is True
    out_1 = paddle.static.nn.case(
        pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)

    # default is None and no pred in pred_fn_pairs is True, so fn_3 is called
    # because it is the last callable in pred_fn_pairs.
    out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])

    exe = paddle.static.Executor(paddle.CPUPlace())
    res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
    print(res_1)  # [[1. 1.]]
    print(res_2)  # [3 3 3]