diff options
Diffstat (limited to 'tests/validation/CL/GEMMReshapeLHSMatrix.cpp')
-rw-r--r-- | tests/validation/CL/GEMMReshapeLHSMatrix.cpp | 50 |
1 files changed, 44 insertions, 6 deletions
diff --git a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp index d252f87958..0dd9b811f6 100644 --- a/tests/validation/CL/GEMMReshapeLHSMatrix.cpp +++ b/tests/validation/CL/GEMMReshapeLHSMatrix.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,11 +21,11 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/CLTensorAllocator.h" +#include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h" #include "tests/CL/CLAccessor.h" #include "tests/CL/Helper.h" #include "tests/PaddingCalculator.h" @@ -43,9 +43,10 @@ namespace test namespace validation { using namespace arm_compute::misc::shape_calculator; +using namespace arm_compute::opencl::kernels; // Initialize the output tensor with zero and fill the border with zero -using CLGEMMReshapeLHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeLHSMatrixKernel, 16>; +using CLGEMMReshapeLHSMatrix = CLSynthetizeOperatorInitOutputWithZeroAndWithZeroConstantBorder<ClGemmReshapeLhsMatrixKernel, 16>; template <typename T> using CLGEMMReshapeLHSMatrixFixture = GEMMReshapeLHSMatrixValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, T, false>; @@ -65,8 +66,10 @@ const auto b_values = framework::dataset::make("batchsize", 1, 3); /** M0 values to test */ const auto m0_values_s32 = framework::dataset::make("M0", { 2, 3 }); -const auto m0_values_s16 = framework::dataset::make("M0", { 4, 5 }); -const auto m0_values_s8 = framework::dataset::make("M0", { 6, 7, 8 }); +const auto m0_values_s16 = framework::dataset::make("M0", { 4 }); +const auto m0_values_s16_nt = framework::dataset::make("M0", { 5 }); +const auto m0_values_s8_nt = framework::dataset::make("M0", { 6,7 }); +const auto m0_values_s8 = framework::dataset::make("M0", { 8 }); /** K0 values to test */ const auto k0_values_s32 = framework::dataset::make("K0", { 2, 3 }); @@ -81,10 +84,12 @@ const auto i_values = framework::dataset::make("interleave", { true, false }); /** Transpose values to test */ const auto t_values = framework::dataset::make("transpose", { true, false }); + } // namespace TEST_SUITE(CL) TEST_SUITE(GEMMReshapeLHSMatrix) + FIXTURE_DATA_TEST_CASE(S32, CLGEMMReshapeLHSMatrixFixture<int>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), @@ -98,6 +103,7 @@ FIXTURE_DATA_TEST_CASE(S32, CLGEMMReshapeLHSMatrixFixture<int>, framework::Datas // Validate output validate(CLAccessor(_target), _reference); } + FIXTURE_DATA_TEST_CASE(S16, CLGEMMReshapeLHSMatrixFixture<short>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), @@ -111,6 +117,7 @@ FIXTURE_DATA_TEST_CASE(S16, CLGEMMReshapeLHSMatrixFixture<short>, framework::Dat // Validate output validate(CLAccessor(_target), _reference); } + FIXTURE_DATA_TEST_CASE(S8, CLGEMMReshapeLHSMatrixFixture<char>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), b_values), @@ -125,6 +132,37 @@ FIXTURE_DATA_TEST_CASE(S8, CLGEMMReshapeLHSMatrixFixture<char>, framework::Datas validate(CLAccessor(_target), _reference); } +TEST_SUITE(NotTransposed) +FIXTURE_DATA_TEST_CASE(S16, CLGEMMReshapeLHSMatrixFixture<short>, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), + b_values), + framework::dataset::make("DataType", DataType::S16)), + m0_values_s16_nt), + k0_values_s16), + v0_values), + i_values), + framework::dataset::make("transpose", { false }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(S8, CLGEMMReshapeLHSMatrixFixture<char>, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape2DShapes(), + b_values), + framework::dataset::make("DataType", DataType::S8)), + m0_values_s8_nt), + k0_values_s8), + v0_values), + i_values), + framework::dataset::make("transpose", { false }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + +TEST_SUITE_END() + TEST_SUITE(ReinterpretInputAs3D) FIXTURE_DATA_TEST_CASE(S32, CLGEMMReshapeLHSMatrix3DFixture<int>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallGEMMReshape3DShapes(), @@ -172,4 +210,4 @@ TEST_SUITE_END() // GEMMReshapeLHSMatrix TEST_SUITE_END() // CL } // namespace validation } // namespace test -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |