jvp¶
计算函数 func
在 xs
处的雅可比矩阵与向量 v
的乘积。
警告
该 API 目前为 Beta 版本,函数签名在未来版本可能发生变化。
参数¶
func (Callable) - Python 函数,输入参数为
xs
,输出为 Tensor 或 Tensor 序列。xs (Tensor|Sequence[Tensor]) - 函数
func
的输入参数,数据类型为 Tensor 或 Tensor 序列。v (Tensor|Sequence[Tensor]|None,可选) - 用于计算
jvp
的输入向量,形状要求 与xs
一致。默认值为None
,即相当于形状与xs
一致,值全为 1 的 Tensor 或 Tensor 序列。
返回¶
func_out (Tensor|tuple[Tensor]) - 函数
func(xs)
的输出。jvp (Tensor|tuple[Tensor]) -
jvp
计算结果。
代码示例¶
>>> import paddle
>>> def func(x):
... return paddle.matmul(x, x)
...
>>> x = paddle.ones(shape=[2, 2], dtype='float32')
>>> _, jvp_result = paddle.incubate.autograd.jvp(func, x)
>>> print(jvp_result)
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
[[4., 4.],
[4., 4.]])
>>> v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
>>> _, jvp_result = paddle.incubate.autograd.jvp(func, x, v)
>>> print(jvp_result)
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
[[2., 1.],
[1., 0.]])