diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2021-06-25 12:13:49 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2021-06-29 16:26:41 +0000 |
commit | 4a578b923ed000c67fe0bc1433f945aea634ca9c (patch) | |
tree | b7bb041d2e7bfb4b909199f1b889585d237c665d /tests/validation | |
parent | 53832b2bcce44c71fe31a618a81765294df55750 (diff) | |
download | ComputeLibrary-4a578b923ed000c67fe0bc1433f945aea634ca9c.tar.gz |
Port the ClGemmLowp kernels to the new API
Ported kernels:
- CLGEMMLowpMatrixMultiplyNativeKernel
- CLGEMMLowpMatrixMultiplyReshapedKernel
- CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel
- CLGEMMLowpOffsetContributionKernel
- CLGEMMLowpOffsetContributionOutputStageKernel
- CLGEMMLowpQuantizeDownInt32ScaleByFixedPointKernel
- CLGEMMLowpQuantizeDownInt32ScaleByFloatKernel
- CLGEMMLowpQuantizeDownInt32ScaleKernel
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I9d5a744d6a2dd2f2726fdfb291bad000b6970de2
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5870
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/CL/GEMMLowp.cpp | 134 | ||||
-rw-r--r-- | tests/validation/CL/GEMMLowpMatrixMultiplyNative.cpp | 6 | ||||
-rw-r--r-- | tests/validation/CL/GEMMLowpMatrixMultiplyReshaped.cpp | 4 | ||||
-rw-r--r-- | tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp | 6 | ||||
-rw-r--r-- | tests/validation/fixtures/GEMMLowpFixture.h | 30 |
5 files changed, 26 insertions, 154 deletions
diff --git a/tests/validation/CL/GEMMLowp.cpp b/tests/validation/CL/GEMMLowp.cpp index 1c7446f653..52adb94c83 100644 --- a/tests/validation/CL/GEMMLowp.cpp +++ b/tests/validation/CL/GEMMLowp.cpp @@ -213,140 +213,6 @@ TEST_SUITE_END() // BoundedReLu TEST_SUITE_END() // QASYMM8_SIGNED TEST_SUITE_END() // QuantizeDownInt32Scale -TEST_SUITE(QuantizeDownInt32ScaleByFixedPoint) - -TEST_SUITE(QASYMM8) - -const auto quantize_down_int32_to_uint8_scale_by_fixedpoint_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1, - 2) - * framework::dataset::make("result_offset_after_shift", 2, 3) * framework::dataset::make("min", 0) * framework::dataset::make("max", 255) * framework::dataset::make("addBias", { false, true }); - -const auto quantize_down_int32_to_uint8_scale_by_fixedpoint_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1, - 2) - * framework::dataset::make("result_offset_after_shift", 2, 3) * framework::dataset::make("min", 0, 2) * framework::dataset::make("max", 171, 174) * framework::dataset::make("addBias", { false, true }); -using CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointFixture = - GEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointValidationFixture<CLTensor, CLAccessor, CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint>; - -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_uint8_scale_by_fixedpoint_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} - -FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointFixture, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), - quantize_down_int32_to_uint8_scale_by_fixedpoint_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} - -TEST_SUITE(BoundedReLu) -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_uint8_scale_by_fixedpoint_relu_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} - -FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointFixture, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), - quantize_down_int32_to_uint8_scale_by_fixedpoint_relu_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} -TEST_SUITE_END() // BoundedReLu -TEST_SUITE_END() // QASYMM8 -TEST_SUITE(QASYMM8_SIGNED) -const auto quantize_down_int32_to_int8_scale_by_fixedpoint_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1, 2) - * framework::dataset::make("result_offset_after_shift", 2, 3) * framework::dataset::make("min", -128) * framework::dataset::make("max", 127) * framework::dataset::make("addBias", { false, true }); - -const auto quantize_down_int32_to_int8_scale_by_fixedpoint_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1, 2) - * framework::dataset::make("result_offset_after_shift", 2, 3) * framework::dataset::make("min", -128, -126) * framework::dataset::make("max", 110, 112) * framework::dataset::make("addBias", { false, true }); -using CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointFixture = - GEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointValidationFixture<CLTensor, CLAccessor, CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPoint>; - -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_int8_scale_by_fixedpoint_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} - -TEST_SUITE(BoundedReLu) -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_int8_scale_by_fixedpoint_relu_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} - -TEST_SUITE_END() // BoundedReLu -TEST_SUITE_END() // QASYMM8_SIGNED -TEST_SUITE(QSYMM16) - -const auto quantize_down_int32_to_int16_scale_by_fixedpoint_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1, - 2) - * framework::dataset::make("min", -32768) * framework::dataset::make("max", 32767) * framework::dataset::make("addBias", { false, true }); - -const auto quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1, - 2) - * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true }); - -const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases = framework::dataset::make("result_fixedpoint_multiplier", 1073741823, - 1073741825) - * framework::dataset::make("result_shift", -3, - -2) - * framework::dataset::make("min", -32768) * framework::dataset::make("max", 32767) * framework::dataset::make("addBias", { false, true }); - -const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, - 254601602) - * framework::dataset::make("result_shift", -3, - -1) - * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true }); - -using CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture = - GEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointValidationFixture<CLTensor, CLAccessor, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint>; - -TEST_SUITE(NoRelu) -TEST_SUITE(MultSmallerEq1) -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_int16_scale_by_fixedpoint_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} -TEST_SUITE_END() // MultSmallerEq1 -TEST_SUITE(MultGreater1) -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} -TEST_SUITE_END() // MultGreater1 -TEST_SUITE_END() // NoRelu -TEST_SUITE(BoundedReLu) -TEST_SUITE(MultSmallerEq1) -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} -TEST_SUITE_END() // MultSmallerEq1 -TEST_SUITE(MultGreater1) -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), - quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases)) -{ - // Validate output - validate(CLAccessor(_target), _reference); -} -TEST_SUITE_END() // MultGreater1 -TEST_SUITE_END() // BoundedReLu -TEST_SUITE_END() // QSYMM16 -TEST_SUITE_END() // QuantizeDownInt32ScaleByFixedPoint - TEST_SUITE(QuantizeDownInt32ScaleByFloat) TEST_SUITE(QASYMM8) diff --git a/tests/validation/CL/GEMMLowpMatrixMultiplyNative.cpp b/tests/validation/CL/GEMMLowpMatrixMultiplyNative.cpp index 1057af95f2..d733a00296 100644 --- a/tests/validation/CL/GEMMLowpMatrixMultiplyNative.cpp +++ b/tests/validation/CL/GEMMLowpMatrixMultiplyNative.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h" +#include "src/core/gpu/cl/kernels/ClGemmLowpMatrixMultiplyNativeKernel.h" #include "tests/CL/CLAccessor.h" #include "tests/CL/Helper.h" #include "tests/framework/Asserts.h" @@ -41,7 +41,7 @@ namespace validation using namespace arm_compute::misc::shape_calculator; // Create function for CLGEMMMatrixMultiplyNativeKernel -using CLGEMMLowpMatrixMultiplyNative = CLSynthetizeFunction<CLGEMMLowpMatrixMultiplyNativeKernel>; +using CLGEMMLowpMatrixMultiplyNative = CLSynthetizeOperator<opencl::kernels::ClGemmLowpMatrixMultiplyNativeKernel>; // Fixture for CLGEMMLowpMatrixMultiplyNative using CLGEMMLowpMatrixMultiplyNativeFixture = GEMMLowpMatrixMultiplyNativeValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyNative>; diff --git a/tests/validation/CL/GEMMLowpMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMLowpMatrixMultiplyReshaped.cpp index 68a7d055ad..3baa39bffc 100644 --- a/tests/validation/CL/GEMMLowpMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMLowpMatrixMultiplyReshaped.cpp @@ -23,7 +23,7 @@ */ #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.h" +#include "src/core/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedKernel.h" #include "src/core/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h" #include "src/core/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" #include "tests/CL/CLAccessor.h" @@ -49,7 +49,7 @@ using CLGEMMReshapeLHSMatrix = CLSynthetizeOperator<opencl::kernels::ClGemmResha using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<opencl::kernels::ClGemmReshapeRhsMatrixKernel>; // Create function for CLGEMMLowpMatrixMultiplyReshapedKernel -using CLGEMMLowpMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMLowpMatrixMultiplyReshapedKernel>; +using CLGEMMLowpMatrixMultiplyReshaped = CLSynthetizeOperator<opencl::kernels::ClGemmLowpMatrixMultiplyReshapedKernel>; // Fixture for CLGEMMLowpMatrixMultiplyReshaped using CLGEMMLowpMatrixMultiplyReshapedFixture = GEMMLowpMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMLowpMatrixMultiplyReshaped>; diff --git a/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp index 43b86b51e8..1283713c4d 100644 --- a/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp +++ b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp @@ -25,7 +25,7 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/CLTensorAllocator.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h" +#include "src/core/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel.h" #include "src/core/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" #include "tests/CL/CLAccessor.h" #include "tests/CL/Helper.h" @@ -49,7 +49,7 @@ using namespace arm_compute::misc::shape_calculator; using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<opencl::kernels::ClGemmReshapeRhsMatrixKernel>; // Create function for CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel -using CLGEMMLowpMatrixMultiplyReshapedOnlyRHS = CLSynthetizeFunction<CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel>; +using CLGEMMLowpMatrixMultiplyReshapedOnlyRHS = CLSynthetizeOperator<opencl::kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel>; // Fixture for CLGEMMLowpMatrixMultiplyReshapedOnlyRHS using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture = GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture<CLTensor, CLAccessor, CLGEMMReshapeRHSMatrix, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS>; @@ -157,7 +157,7 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned // Create and configure function CLGEMMLowpMatrixMultiplyReshapedOnlyRHS gemm; - gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info); + gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info); } } // namespace diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index ab9d35de0f..5e2154592e 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -959,7 +959,7 @@ protected: GEMMFunctionType gemm; reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info); reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); - gemm.configure(&lhs_reshaped, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K)); + gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K)); ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); @@ -988,7 +988,8 @@ protected: reshape_lhs.run(reshape_lhs_pack); ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; reshape_rhs.run(reshape_rhs_pack); - gemm.run(); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs_reshaped }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); return dst; } @@ -1113,7 +1114,7 @@ protected: GEMMFunctionType gemm; reshape_lhs.configure(lhs.info(), lhs_reshaped.info(), lhs_info); reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); - gemm.configure(&lhs_reshaped, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h)); + gemm.configure(lhs_reshaped.info(), rhs_reshaped.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h)); ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); @@ -1142,7 +1143,8 @@ protected: reshape_lhs.run(reshape_lhs_pack); ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; reshape_rhs.run(reshape_rhs_pack); - gemm.run(); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs_reshaped }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); return dst; } @@ -1266,7 +1268,7 @@ protected: ReshapeRHSOperatorType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info); + gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info); ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); @@ -1291,7 +1293,8 @@ protected: // Compute GEMM ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; reshape_rhs.run(reshape_rhs_pack); - gemm.run(); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); return dst; } @@ -1412,7 +1415,7 @@ protected: ReshapeRHSOperatorType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info); + gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info); ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); @@ -1437,7 +1440,8 @@ protected: // Compute GEMM ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; reshape_rhs.run(reshape_rhs_pack); - gemm.run(); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); return dst; } @@ -1527,7 +1531,7 @@ protected: // Create and configure function GEMMFunctionType gemm; - gemm.configure(&lhs, &rhs, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K)); + gemm.configure(lhs.info(), rhs.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K)); ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); @@ -1548,7 +1552,8 @@ protected: fill(AccessorType(rhs), 1); // Compute GEMM - gemm.run(); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); return dst; } @@ -1624,7 +1629,7 @@ protected: // Create and configure function GEMMFunctionType gemm; - gemm.configure(&lhs, &rhs, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h)); + gemm.configure(lhs.info(), rhs.info(), dst.info(), lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h)); ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); @@ -1645,7 +1650,8 @@ protected: fill(AccessorType(rhs), 1); // Compute GEMM - gemm.run(); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); return dst; } |