diff options
Diffstat (limited to 'tests/datasets')
-rw-r--r-- | tests/datasets/ScatterDataset.h | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h index 4ad269ec85..8fd4448d2d 100644 --- a/tests/datasets/ScatterDataset.h +++ b/tests/datasets/ScatterDataset.h @@ -180,7 +180,6 @@ public: // NOTE: Updates/Indices tensors are now batched. // NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1 // k is the number of batch dimensions - // k = 2 add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U)); add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 6U, 2U), TensorShape(3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U)); @@ -197,6 +196,18 @@ public: } }; +class SmallScatterScalarDataset final : public ScatterDataset +{ +public: + // batched scalar case + SmallScatterScalarDataset() + { + add_config(TensorShape(6U, 5U), TensorShape(6U), TensorShape(2U, 6U), TensorShape(6U, 5U)); + add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U)); + add_config(TensorShape(3U, 3U, 6U, 5U), TensorShape(6U, 6U), TensorShape(4U, 6U, 6U), TensorShape(3U, 3U, 6U, 5U)); + } +}; + // This dataset is for data types that does not require full testing. It contains selected tests from the above. class SmallScatterMixedDataset final : public ScatterDataset { @@ -205,6 +216,7 @@ public: { add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U)); add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U)); + add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U)); add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U)); add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U)); add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U)); |