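ParameterDict
A dictionary container for parameters. It behaves like a Python dict, but every parameter it holds is registered as a parameter of the container. It is therefore also registered with any `paddle.nn.Layer` that holds the container as an attribute, so these parameters are tracked like any other layer parameter (for example, they are returned by that layer's `parameters()`).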
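To see why the registration matters, here is a minimal pure-Python sketch of the idea. `ToyLayer`, `ToyParamDict` and `Model` are hypothetical classes written for illustration only; they are not Paddle's implementation, and strings stand in for real parameter tensors. The sketch assumes a layer collects parameters only from itself and from sub-layers it has recorded, so values kept in a plain `dict` attribute are silently missed while values in the dict-like container are found.

```python
from collections import OrderedDict


class ToyLayer:
    """Hypothetical minimal layer: collects parameters from itself and its sub-layers."""

    def __init__(self):
        # object.__setattr__ avoids running our own __setattr__ during setup
        object.__setattr__(self, "_params", OrderedDict())
        object.__setattr__(self, "_sublayers", OrderedDict())

    def __setattr__(self, name, value):
        # only ToyLayer attributes are recorded; a plain dict is stored but never inspected
        if isinstance(value, ToyLayer):
            self._sublayers[name] = value
        object.__setattr__(self, name, value)

    def parameters(self):
        # own parameters first, then those of every registered sub-layer, in order
        result = list(self._params.values())
        for layer in self._sublayers.values():
            result.extend(layer.parameters())
        return result


class ToyParamDict(ToyLayer):
    """Hypothetical dict-like container that registers every value it stores."""

    def __init__(self, items=None):
        super().__init__()
        for key, value in (items or {}).items():
            self[key] = value

    def __setitem__(self, key, value):
        self._params[key] = value  # registering makes the value visible to parameters()

    def __getitem__(self, key):
        return self._params[key]

    def __iter__(self):
        return iter(self._params)

    def __len__(self):
        return len(self._params)


class Model(ToyLayer):
    def __init__(self):
        super().__init__()
        # strings stand in for real parameter tensors
        self.plain = {"a": "param_a"}                  # plain dict: invisible to parameters()
        self.tracked = ToyParamDict({"b": "param_b"})  # registered container


model = Model()
print(model.parameters())  # ['param_b']
```

In the same way, parameters kept in a plain Python `dict` attribute of a `paddle.nn.Layer` are not registered by the layer, which is the situation `ParameterDict` is meant for.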
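Parameters
parameters (ParameterDict | Mapping[str, Tensor] | Sequence[tuple[str, Tensor]], optional) - The initial parameters: another ParameterDict, a mapping from names to Tensors, or a sequence of (name, Tensor) pairs. Default: None, which creates an empty container.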
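The accepted argument forms all describe the same thing: an ordered collection of (name, parameter) pairs. The sketch below normalizes a mapping, a sequence of pairs, and `None` into one list of pairs. `normalize_parameters` is a hypothetical helper for illustration, not a Paddle function; it leaves out the ParameterDict-to-ParameterDict case, and plain values stand in for Tensors.

```python
from collections.abc import Mapping


def normalize_parameters(parameters=None):
    """Hypothetical helper: turn accepted input forms into a list of (name, value) pairs."""
    if parameters is None:
        return []  # None -> empty container
    if isinstance(parameters, Mapping):
        return list(parameters.items())  # keeps the mapping's iteration order
    pairs = []
    for item in parameters:
        # each element of a sequence must be exactly one (name, value) pair
        if len(item) != 2:
            raise ValueError(f"expected a (name, value) pair, got {item!r}")
        name, value = item
        pairs.append((name, value))
    return pairs


print(normalize_parameters({"w": "W", "b": "B"}))      # [('w', 'W'), ('b', 'B')]
print(normalize_parameters([("w", "W"), ("b", "B")]))  # [('w', 'W'), ('b', 'B')]
```

Both calls yield the same pairs in the same order, which is why a dict and a list of tuples can be used interchangeably when building the container.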
Returns
None
Code Example
>>> import paddle
>>> class MyLayer(paddle.nn.Layer):
...     def __init__(self, num_stacked_param):
...         super().__init__()
...         # create a ParameterDict from a dict mapping names to Parameters
...         self.params = paddle.nn.ParameterDict(
...             {f"t{i}": paddle.create_parameter(shape=[2, 2], dtype='float32')
...              for i in range(num_stacked_param)})
...
...     def forward(self, x):
...         # multiply x by every stored parameter, in the dict's iteration order
...         for key in self.params:
...             x = paddle.matmul(x, self.params[key])
...         return x
...
>>> x = paddle.uniform(shape=[5, 2], dtype='float32')
>>> num_stacked_param = 4
>>> model = MyLayer(num_stacked_param)
>>> print(len(model.params))
4
>>> res = model(x)
>>> print(res.shape)
[5, 2]
>>> replaced_param = paddle.create_parameter(shape=[2, 3], dtype='float32')
>>> model.params['t3'] = replaced_param # replace t3 param
>>> res = model(x)
>>> print(res.shape)
[5, 3]
>>> model.params['t4'] = paddle.create_parameter(shape=[3, 4], dtype='float32') # append param
>>> print(len(model.params))
5
>>> res = model(x)
>>> print(res.shape)
[5, 4]
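The printed shapes follow from the order in which `forward` visits the parameters: `t3` is the last key, so the new `[2, 3]` parameter is the last factor, and `t4` is appended after it. The plain-Python shape check below reproduces the three printed shapes without any tensor math. `chained_matmul_shape` is a hypothetical helper, and it assumes the container iterates in insertion order, as the example's output implies.

```python
def chained_matmul_shape(x_shape, param_shapes):
    """Output shape of x @ p0 @ p1 @ ... for 2-D shapes, applied in the given order."""
    rows, cols = x_shape
    for p_rows, p_cols in param_shapes:
        # inner dimensions must agree for each matrix product
        if cols != p_rows:
            raise ValueError(f"cannot multiply [{rows}, {cols}] by [{p_rows}, {p_cols}]")
        cols = p_cols
    return [rows, cols]


# Parameter shapes, in iteration order, at each stage of the example above.
initial = [[2, 2]] * 4                   # t0..t3
after_replace = [[2, 2]] * 3 + [[2, 3]]  # t3 replaced by a [2, 3] parameter
after_append = after_replace + [[3, 4]]  # t4 added after t3

print(chained_matmul_shape([5, 2], initial))        # [5, 2]
print(chained_matmul_shape([5, 2], after_replace))  # [5, 3]
print(chained_matmul_shape([5, 2], after_append))   # [5, 4]
```

If the `[3, 4]` parameter were visited before the `[2, 3]` one, the chain would fail on mismatched inner dimensions, so the example relies on the dict's ordering.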