case
- paddle.static.nn.case(pred_fn_pairs, default=None, name=None)
- API attribute: Static Graph
This operator works like an if-elif-elif-else chain: it returns the result of the callable paired with the first predicate that is True.
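As a rough mental model, and ignoring graph construction entirely, the first `case` call in the example further down behaves like the plain-Python chain below over the same scalar values. `case_as_chain` and the list-returning `fn_*` stand-ins are hypothetical and only illustrate the branch order. In static mode the real operator builds conditional blocks into the program rather than running Python branches.

```python
def fn_1():
    return [[1.0, 1.0]]  # stands in for paddle.full([1, 2], 1, 'float32')

def fn_2():
    return [[2, 2], [2, 2]]  # stands in for paddle.full([2, 2], 2, 'int32')

def fn_3():
    return [3, 3, 3]  # stands in for paddle.full([3], 3, 'int32')

def case_as_chain(x, y, z):
    # Hypothetical plain-Python equivalent of
    # case([(z < x, fn_1), (x < y, fn_2)], default=fn_3)
    if z < x:
        return fn_1()
    elif x < y:
        return fn_2()
    else:
        return fn_3()

print(case_as_chain(0.3, 0.1, 0.2))  # [[1.0, 1.0]]
```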
- Parameters
    - pred_fn_pairs (list|tuple) – A list or tuple of (pred, fn) pairs. pred is a boolean Tensor whose numel should be 1 (shape [] or shape [1]), and fn is a callable. All callables must return the same structure of Tensors.
    - default (callable, optional) – Callable that returns a structure of Tensors. The default value is None.
    - name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
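The `pred` requirement concerns element count, not rank. A tensor's numel is the product of its shape, and the empty product is 1, so a 0-D tensor (shape `[]`) and a shape-`[1]` tensor both have numel 1. The helpers `numel_of_shape` and `is_valid_pred_shape` below are hypothetical plain-Python illustrations, not Paddle functions.

```python
import math

def numel_of_shape(shape):
    # Number of elements implied by a shape; math.prod([]) == 1.
    return math.prod(shape)

def is_valid_pred_shape(shape):
    # Hypothetical check mirroring the documented rule: numel must be 1.
    return numel_of_shape(shape) == 1

for shape in ([], [1], [2], [1, 3]):
    print(shape, is_valid_pred_shape(shape))
```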
- Returns
    - Tensors returned by the callable of the first pair whose pred is True.
    - If no pred in pred_fn_pairs is True and default is not None: Tensors returned by default.
    - If no pred in pred_fn_pairs is True and default is None: Tensors returned by the last callable in pred_fn_pairs.
- Return type
    Tensor|list(Tensor)
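The three-way selection rule under Returns can be restated as a short plain-Python function. `case_model` is a hypothetical helper that takes ordinary bools instead of 1-element boolean Tensors and calls the chosen function directly. It models only which result is returned, not how the static-graph operator builds its conditional blocks.

```python
def case_model(pred_fn_pairs, default=None):
    # Hypothetical plain-Python model of which callable case() selects.
    # Each pred is a plain bool here instead of a 1-element boolean Tensor.
    for pred, fn in pred_fn_pairs:
        if pred:
            return fn()  # first pair whose pred is True wins
    if default is not None:
        return default()  # no pred is True: use default
    return pred_fn_pairs[-1][1]()  # no pred True and no default: last callable

print(case_model([(True, lambda: "a"), (True, lambda: "b")]))    # a
print(case_model([(False, lambda: "a")], default=lambda: "d"))   # d
print(case_model([(False, lambda: "a"), (False, lambda: "b")]))  # b
```

In the last call the final callable runs even though its own predicate is False, because it serves as the fallback when default is None.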
- Raises
    - TypeError – If the type of pred_fn_pairs is not list or tuple.
    - TypeError – If any element of pred_fn_pairs is not a tuple.
    - TypeError – If any tuple in pred_fn_pairs does not have exactly 2 elements.
    - TypeError – If the first element of a 2-tuple in pred_fn_pairs is not a Tensor.
    - TypeError – If the second element of a 2-tuple in pred_fn_pairs is not callable.
    - TypeError – If default is not None but is not callable.
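The argument checks listed under Raises can be sketched as plain-Python validation with one check per documented condition. `FakeTensor` and `check_case_args` are hypothetical names, the Tensor check is reduced to an `isinstance` test on the stand-in class, and the order and wording of Paddle's actual checks may differ.

```python
class FakeTensor:
    """Hypothetical stand-in for a Paddle Tensor in this sketch."""

def check_case_args(pred_fn_pairs, default=None):
    # One check per documented TypeError condition (order is an assumption).
    if not isinstance(pred_fn_pairs, (list, tuple)):
        raise TypeError("pred_fn_pairs must be a list or tuple")
    for pair in pred_fn_pairs:
        if not isinstance(pair, tuple):
            raise TypeError("each element of pred_fn_pairs must be a tuple")
        if len(pair) != 2:
            raise TypeError("each tuple in pred_fn_pairs must have size 2")
        pred, fn = pair
        if not isinstance(pred, FakeTensor):
            raise TypeError("pred must be a Tensor")
        if not callable(fn):
            raise TypeError("fn must be callable")
    if default is not None and not callable(default):
        raise TypeError("default must be callable or None")

try:
    check_case_args([(FakeTensor(), "not callable")])
except TypeError as e:
    print(e)  # fn must be callable
```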
Examples
>>> import paddle
>>> paddle.enable_static()
>>> def fn_1():
...     return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)
>>> def fn_2():
...     return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)
>>> def fn_3():
...     return paddle.full(shape=[3], dtype='int32', fill_value=3)
>>> main_program = paddle.static.default_main_program()
>>> startup_program = paddle.static.default_startup_program()
>>> with paddle.static.program_guard(main_program, startup_program):
...     x = paddle.full(shape=[1], dtype='float32', fill_value=0.3)
...     y = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
...     z = paddle.full(shape=[1], dtype='float32', fill_value=0.2)
...     pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
...     pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
...     pred_3 = paddle.equal(x, y)      # false: 0.3 == 0.1
...     # Call fn_1 because pred_1 is True
...     out_1 = paddle.static.nn.case(
...         pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
...     # Argument default is None and no pred in pred_fn_pairs is True,
...     # so fn_3 is called because it is the last callable in pred_fn_pairs.
...     out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
...     exe = paddle.static.Executor(paddle.CPUPlace())
...     res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
...     print(res_1, res_2)
[[1. 1.]] [3 3 3]