TensorDataset
- class paddle.io.TensorDataset(tensors)
A dataset defined by a list of tensors.
Each tensor should have shape [N, …], where N is the number of samples, and each tensor holds one field of a sample.
TensorDataset retrieves each sample by indexing every tensor along its first dimension.
- Parameters
tensors (list|tuple) – A list or tuple of tensors that all have the same size in the 1st dimension.
- Returns
A Dataset instance wrapping tensors.
- Return type
Dataset
Examples
>>> import numpy as np
>>> import paddle
>>> from paddle.io import TensorDataset
>>> input_np = np.random.random([2, 3, 4]).astype('float32')
>>> input = paddle.to_tensor(input_np)
>>> # integer labels in [0, 10); casting random floats in [0, 1) to int32 would give all zeros
>>> label_np = np.random.randint(0, 10, [2, 1]).astype('int32')
>>> label = paddle.to_tensor(label_np)
>>> dataset = TensorDataset([input, label])
>>> for i in range(len(dataset)):
...     # sample i is the tuple (input[i], label[i])
...     sample_input, sample_label = dataset[i]
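The indexing behavior described above can be illustrated with a minimal pure-Python sketch. This is not Paddle's source: the class name MiniTensorDataset and the ValueError raised on mismatched first dimensions are illustrative assumptions, and plain nested lists stand in for tensors so the sketch runs without Paddle installed.

```python
from typing import Any, Sequence, Tuple


class MiniTensorDataset:
    """Hypothetical stand-in for paddle.io.TensorDataset (not Paddle's code).

    Each element of `tensors` is any sequence whose first dimension has
    length N; sample i is the tuple (tensors[0][i], tensors[1][i], ...).
    """

    def __init__(self, tensors: Sequence[Sequence[Any]]) -> None:
        if len(tensors) == 0:
            raise ValueError("need at least one tensor")
        n = len(tensors[0])
        # Every field must have the same size N in its 1st dimension.
        # (Assumption: the real class may signal this differently.)
        for t in tensors:
            if len(t) != n:
                raise ValueError("tensors must have the same size in the 1st dimension")
        self.tensors = list(tensors)

    def __len__(self) -> int:
        # Number of samples N = size of the first dimension.
        return len(self.tensors[0])

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        # Index every tensor along its first dimension and group the fields.
        return tuple(t[index] for t in self.tensors)


# Two samples: inputs of shape [2, 3] and labels of shape [2, 1], as nested lists.
inputs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
labels = [[0], [1]]
dataset = MiniTensorDataset([inputs, labels])
for i in range(len(dataset)):
    x, y = dataset[i]
    print(i, x, y)
# prints:
# 0 [0.1, 0.2, 0.3] [0]
# 1 [0.4, 0.5, 0.6] [1]
```

Because sample i is built lazily from the wrapped tensors, no data is copied when the dataset is created; each field keeps its own trailing shape, and only the first dimension is shared.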