diff options
Diffstat (limited to 'tests/datasets')
-rw-r--r-- | tests/datasets/LargeMatMulDataset.h (renamed from tests/datasets/LargeBatchMatMulDataset.h) | 16 | ||||
-rw-r--r-- | tests/datasets/MatMulDataset.h (renamed from tests/datasets/BatchMatMulDataset.h) | 14 | ||||
-rw-r--r-- | tests/datasets/SmallMatMulDataset.h (renamed from tests/datasets/SmallBatchMatMulDataset.h) | 23 |
3 files changed, 32 insertions, 21 deletions
diff --git a/tests/datasets/LargeBatchMatMulDataset.h b/tests/datasets/LargeMatMulDataset.h index 0d8ff913cf..cbc97d5e4a 100644 --- a/tests/datasets/LargeBatchMatMulDataset.h +++ b/tests/datasets/LargeMatMulDataset.h @@ -21,12 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ACL_TESTS_DATASETS_LARGEBATCHMATMULDATASET -#define ACL_TESTS_DATASETS_LARGEBATCHMATMULDATASET +#ifndef ACL_TESTS_DATASETS_LARGEMATMULDATASET +#define ACL_TESTS_DATASETS_LARGEMATMULDATASET #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" -#include "tests/datasets/BatchMatMulDataset.h" +#include "tests/datasets/MatMulDataset.h" namespace arm_compute { @@ -34,10 +34,10 @@ namespace test { namespace datasets { -class LargeBatchMatMulDataset final : public BatchMatMulDataset +class LargeMatMulDataset final : public MatMulDataset { public: - LargeBatchMatMulDataset() + LargeMatMulDataset() { add_config(TensorShape(21U, 13U, 3U, 2U), TensorShape(33U, 21U, 3U, 2U), TensorShape(33U, 13U, 3U, 2U)); add_config(TensorShape(38U, 12U, 1U, 5U), TensorShape(21U, 38U, 1U, 5U), TensorShape(21U, 12U, 1U, 5U)); @@ -45,10 +45,10 @@ public: } }; -class HighDimensionalBatchMatMulDataset final : public BatchMatMulDataset +class HighDimensionalMatMulDataset final : public MatMulDataset { public: - HighDimensionalBatchMatMulDataset() + HighDimensionalMatMulDataset() { add_config(TensorShape(5U, 5U, 2U, 2U, 2U, 2U), TensorShape(5U, 5U, 2U, 2U, 2U, 2U), TensorShape(5U, 5U, 2U, 2U, 2U, 2U)); // 6D tensor } @@ -57,4 +57,4 @@ public: } // namespace datasets } // namespace test } // namespace arm_compute -#endif /* ACL_TESTS_DATASETS_LARGEBATCHMATMULDATASET */ +#endif /* ACL_TESTS_DATASETS_LARGEMATMULDATASET */ diff --git a/tests/datasets/BatchMatMulDataset.h b/tests/datasets/MatMulDataset.h index dad7cc0af4..9c1c5fb05d 100644 --- a/tests/datasets/BatchMatMulDataset.h +++ b/tests/datasets/MatMulDataset.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef TESTS_DATASETS_BATCHMATMULDATASET -#define TESTS_DATASETS_BATCHMATMULDATASET +#ifndef ACL_TESTS_DATASETS_MATMULDATASET +#define ACL_TESTS_DATASETS_MATMULDATASET #include "arm_compute/core/TensorShape.h" #include "utils/TypePrinter.h" @@ -33,7 +33,7 @@ namespace test { namespace datasets { -class BatchMatMulDataset +class MatMulDataset { public: using type = std::tuple<TensorShape, TensorShape, TensorShape>; @@ -58,7 +58,7 @@ public: return description.str(); } - BatchMatMulDataset::type operator*() const + MatMulDataset::type operator*() const { return std::make_tuple(*_a_it, *_b_it, *_dst_it); } @@ -96,8 +96,8 @@ public: } protected: - BatchMatMulDataset() = default; - BatchMatMulDataset(BatchMatMulDataset &&) = default; + MatMulDataset() = default; + MatMulDataset(MatMulDataset &&) = default; private: std::vector<TensorShape> _a_shapes{}; @@ -107,4 +107,4 @@ private: } // namespace datasets } // namespace test } // namespace arm_compute -#endif /* TESTS_DATASETS_BATCHMATMULDATASET */ +#endif /* ACL_TESTS_DATASETS_MATMULDATASET */ diff --git a/tests/datasets/SmallBatchMatMulDataset.h b/tests/datasets/SmallMatMulDataset.h index cfe76bea6d..ae92b9abf5 100644 --- a/tests/datasets/SmallBatchMatMulDataset.h +++ b/tests/datasets/SmallMatMulDataset.h @@ -21,12 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ACL_TESTS_DATASETS_SMALLBATCHMATMULDATASET -#define ACL_TESTS_DATASETS_SMALLBATCHMATMULDATASET +#ifndef ACL_TESTS_DATASETS_SMALLMATMULDATASET +#define ACL_TESTS_DATASETS_SMALLMATMULDATASET #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" -#include "tests/datasets/BatchMatMulDataset.h" +#include "tests/datasets/MatMulDataset.h" namespace arm_compute { @@ -34,10 +34,10 @@ namespace test { namespace datasets { -class SmallBatchMatMulDataset final : public BatchMatMulDataset +class SmallMatMulDataset final : public MatMulDataset { public: - SmallBatchMatMulDataset() + SmallMatMulDataset() { add_config(TensorShape(3U, 4U, 2U, 2U), TensorShape(2U, 3U, 2U, 2U), TensorShape(2U, 4U, 2U, 2U)); add_config(TensorShape(9U, 6U), TensorShape(5U, 9U), TensorShape(5U, 6U)); @@ -46,7 +46,18 @@ public: add_config(TensorShape(32U, 2U), TensorShape(17U, 32U), TensorShape(17U, 2U)); } }; + +class TinyMatMulDataset final : public MatMulDataset +{ +public: + TinyMatMulDataset() + { + add_config(TensorShape(1U), TensorShape(1U), TensorShape(1U)); + add_config(TensorShape(2U, 2U), TensorShape(2U, 2U), TensorShape(2U, 2U)); + } +}; + } // namespace datasets } // namespace test } // namespace arm_compute -#endif /* ACL_TESTS_DATASETS_SMALLBATCHMATMULDATASET */ +#endif /* ACL_TESTS_DATASETS_SMALLMATMULDATASET */ |