aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets')
-rw-r--r--tests/datasets/LargeGEMMDataset.h21
-rw-r--r--tests/datasets/ScatterDataset.h104
-rw-r--r--tests/datasets/SmallGEMMDataset.h19
3 files changed, 134 insertions, 10 deletions
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 6cdff7f559..e45319ef57 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
-#define ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
+#define ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -79,7 +79,20 @@ public:
add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
}
};
+
+class LargeAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ LargeAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 429U), TensorShape(871U, 429U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_LARGE_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index d204d17855..8fd4448d2d 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -113,13 +113,113 @@ private:
std::vector<TensorShape> _dst_shapes{};
};
+
+// 1D dataset for simple scatter tests.
class Small1DScatterDataset final : public ScatterDataset
{
public:
Small1DScatterDataset()
{
- add_config(TensorShape(6U), TensorShape(6U), TensorShape(6U), TensorShape(6U));
- add_config(TensorShape(10U), TensorShape(2U), TensorShape(2U), TensorShape(10U));
+ add_config(TensorShape(6U), TensorShape(6U), TensorShape(1U, 6U), TensorShape(6U));
+ add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
+ }
+};
+
+// This dataset represents the (m+1)-D updates/dst case.
+class SmallScatterMultiDimDataset final : public ScatterDataset
+{
+public:
+ SmallScatterMultiDimDataset()
+ {
+ // NOTE: Config is src, updates, indices, output.
+ // - In this config, the dim replaced is the final number (largest tensor dimension)
+ // - Largest "updates" dim should match y-dim of indices.
+ // - src/updates/dst should all have same number of dims. Indices should be 2D.
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 2U), TensorShape(1U, 2U), TensorShape(6U, 5U));
+ add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+ add_config(TensorShape(17U, 3U, 2U, 4U), TensorShape(17U, 3U, 2U, 7U), TensorShape(1U, 7U), TensorShape(17U, 3U, 2U, 4U));
+ }
+};
+
+// This dataset represents the (m+1)-D updates tensor, (m+n)-d output tensor cases
+class SmallScatterMultiIndicesDataset final : public ScatterDataset
+{
+public:
+ SmallScatterMultiIndicesDataset()
+ {
+ // NOTE: Config is src, updates, indices, output.
+ // NOTE: indices.shape.x = src.num_dimensions - updates.num_dimensions + 1
+
+ // index length is 2
+ add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 4U), TensorShape(2U, 4U), TensorShape(6U, 5U, 2U));
+ add_config(TensorShape(17U, 3U, 3U, 2U), TensorShape(17U, 3U, 2U), TensorShape(2U, 2U), TensorShape(17U, 3U, 3U, 2U));
+ add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
+ add_config(TensorShape(5U, 4U, 3U, 3U, 2U, 4U), TensorShape(5U, 4U, 3U, 3U, 5U), TensorShape(2U, 5U), TensorShape(5U, 4U, 3U, 3U, 2U, 4U));
+
+ // index length is 3
+ add_config(TensorShape(4U, 3U, 2U, 2U), TensorShape(4U, 2U), TensorShape(3U, 2U), TensorShape(4U, 3U, 2U, 2U));
+ add_config(TensorShape(17U, 4U, 3U, 2U, 2U), TensorShape(17U, 4U, 4U), TensorShape(3U, 4U), TensorShape(17U, 4U, 3U, 2U, 2U));
+ add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 4U, 5U, 3U), TensorShape(3U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
+
+ // index length is 4
+ add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
+ add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 4U, 3U), TensorShape(4U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
+
+ // index length is 5
+ add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 3U), TensorShape(5U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
+ }
+};
+
+// This dataset represents the (m+k)-D updates tensor, (k+1)-d indices tensor and (m+n)-d output tensor cases
+class SmallScatterBatchedDataset final : public ScatterDataset
+{
+public:
+ SmallScatterBatchedDataset()
+ {
+ // NOTE: Config is src, updates, indices, output.
+ // NOTE: Updates/Indices tensors are now batched.
+ // NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1
+ // k is the number of batch dimensions
+ // k = 2
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U));
+ add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 6U, 2U), TensorShape(3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
+
+ // k = 3
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U, 2U), TensorShape(1U, 2U, 2U, 2U), TensorShape(6U, 5U));
+ add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 3U, 6U, 2U), TensorShape(3U, 3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
+
+ // k = 4
+ add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 6U, 2U, 3U, 2U), TensorShape(4U, 6U, 2U, 3U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
+
+ // k = 5
+ add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 3U, 4U, 3U, 2U, 2U), TensorShape(4U, 3U, 4U, 3U, 2U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
+ }
+};
+
+class SmallScatterScalarDataset final : public ScatterDataset
+{
+public:
+ // batched scalar case
+ SmallScatterScalarDataset()
+ {
+ add_config(TensorShape(6U, 5U), TensorShape(6U), TensorShape(2U, 6U), TensorShape(6U, 5U));
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
+ add_config(TensorShape(3U, 3U, 6U, 5U), TensorShape(6U, 6U), TensorShape(4U, 6U, 6U), TensorShape(3U, 3U, 6U, 5U));
+ }
+};
+
+// This dataset is for data types that does not require full testing. It contains selected tests from the above.
+class SmallScatterMixedDataset final : public ScatterDataset
+{
+public:
+ SmallScatterMixedDataset()
+ {
+ add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
+ add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
+ add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
+ add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
+ add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));
}
};
} // namespace datasets
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index c12f57b266..99c7abbf64 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
-#define ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
+#define ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -97,7 +97,18 @@ public:
}
};
+class SmallAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ SmallAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(8U, 2U), TensorShape(16U, 8U), TensorShape(16U, 2U), TensorShape(16U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(21U, 13U), TensorShape(33U, 21U), TensorShape(33U, 13U), TensorShape(33U, 13U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_SMALL_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_SMALLGEMMDATASET_H