ReshapeTransform¶
ReshapeTransform
将输入 Tensor 的事件形状 in_event_shape
改变为 out_event_shape
。其中,in_event_shape
、out_event_shape
需要包含相同的元素个数。
参数¶
in_event_shape (Sequence[int]) - Reshape 前的事件形状。
out_event_shape (float|Tensor) - Reshape 后的事件形状。
代码示例¶
import paddle
x = paddle.ones((1,2,3))
reshape_transform = paddle.distribution.ReshapeTransform((2, 3), (3, 2))
print(reshape_transform.forward_shape((1,2,3)))
# (5, 2, 6)
print(reshape_transform.forward(x))
# Tensor(shape=[1, 3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[[1., 1.],
# [1., 1.],
# [1., 1.]]])
print(reshape_transform.inverse(reshape_transform.forward(x)))
# Tensor(shape=[1, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[[1., 1., 1.],
# [1., 1., 1.]]])
print(reshape_transform.forward_log_det_jacobian(x))
# Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# 0.)
方法¶
forward(x)¶
计算正变换 \(y=f(x)\) 的结果。
参数
x (Tensor) - 正变换输入参数,通常为 Distribution 的随机采样结果。
返回
Tensor,正变换的计算结果。
forward_log_det_jacobian(x)¶
计算正变换雅可比行列式绝对值的对数。
如果变换不是一一映射,则雅可比矩阵不存在,抛出 NotImplementedError
。
参数
x (Tensor) - 输入参数。
返回
Tensor,正变换雅可比行列式绝对值的对数。
inverse_log_det_jacobian(y)¶
计算逆变换雅可比行列式绝对值的对数。
与 forward_log_det_jacobian
互为负数。
参数
y (Tensor) - 输入参数。
返回
Tensor,逆变换雅可比行列式绝对值的对数。