diff options
Diffstat (limited to 'tests/validation/NEON/GEMM.cpp')
-rw-r--r-- | tests/validation/NEON/GEMM.cpp | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp index 57e8ce7ea9..1145d0b79e 100644 --- a/tests/validation/NEON/GEMM.cpp +++ b/tests/validation/NEON/GEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017, 2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,6 +22,7 @@ * SOFTWARE. */ #include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h" +#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/NEON/functions/NEGEMM.h" #include "arm_compute/runtime/Tensor.h" @@ -37,6 +38,7 @@ #include "tests/validation/Validation.h" #include "tests/validation/fixtures/GEMMFixture.h" #include "tests/validation/fixtures/GEMMInterleave4x4Fixture.h" +#include "tests/validation/fixtures/GEMMTranspose1xWFixture.h" namespace arm_compute { @@ -61,11 +63,53 @@ const auto CNNDataTypes = framework::dataset::make("DataType", }); const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::dataset::make("N", 8, 12); +const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14); + } // namespace TEST_SUITE(NEON) TEST_SUITE(GEMM) +TEST_SUITE(TRANSPOSE_1XW) +using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 4>; +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, NEGEMMTranspose1xW, float>; +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::F32)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() // FP32 + +TEST_SUITE(Quantized) +TEST_SUITE(QS8) +using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 16>; +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixedPointFixture<Tensor, Accessor, NEGEMMTranspose1xW, int8_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * + framework::dataset::make("DataType", DataType::QS8) + * framework::dataset::make("FractionalBits", 1, 7)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() + +TEST_SUITE(QS16) +using NEGEMMTranspose1xW = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMTranspose1xWKernel, 8>; +using NEGEMMTranspose1xWFixture = GEMMTranspose1xWValidationFixedPointFixture<Tensor, Accessor, NEGEMMTranspose1xW, int16_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * + framework::dataset::make("DataType", DataType::QS16) + * framework::dataset::make("FractionalBits", 1, 14)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() + +TEST_SUITE_END() + +TEST_SUITE_END() // TRANSPOSE_1XW + TEST_SUITE(INTERLEAVE_4X4) using NEGEMMInterleave4x4 = NESynthetizeFunctionWithZeroConstantBorder<NEGEMMInterleave4x4Kernel, 4>; |