diff options
Diffstat (limited to 'examples/gemm_tuner/cl_gemmlowp_reshaped_rhs_only_fused_output_stage_fixedpoint.cpp')
-rw-r--r-- | examples/gemm_tuner/cl_gemmlowp_reshaped_rhs_only_fused_output_stage_fixedpoint.cpp | 30 |
1 files changed, 18 insertions, 12 deletions
diff --git a/examples/gemm_tuner/cl_gemmlowp_reshaped_rhs_only_fused_output_stage_fixedpoint.cpp b/examples/gemm_tuner/cl_gemmlowp_reshaped_rhs_only_fused_output_stage_fixedpoint.cpp index ca7b7a5f04..15c1b86c61 100644 --- a/examples/gemm_tuner/cl_gemmlowp_reshaped_rhs_only_fused_output_stage_fixedpoint.cpp +++ b/examples/gemm_tuner/cl_gemmlowp_reshaped_rhs_only_fused_output_stage_fixedpoint.cpp @@ -35,8 +35,8 @@ #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/CL/CLScheduler.h" #include "arm_compute/runtime/CL/CLTuner.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h" +#include "src/core/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel.h" +#include "src/core/gpu/cl/kernels/ClGemmLowpReductionKernel.h" #include "tests/CL/Helper.h" #include "utils/Utils.h" #include "utils/command_line/CommandLineOptions.h" @@ -47,6 +47,7 @@ using namespace arm_compute; using namespace utils; +using namespace arm_compute::opencl::kernels; using namespace arm_compute::misc::shape_calculator; using namespace gemm_tuner; @@ -146,8 +147,8 @@ GemmConfigs consume_gemm_configs(const GemmConfigOptions &options) } // namespace -using CLGEMMLowpMatrixMultiplyReshapedOnlyRHS = test::CLSynthetizeFunction<CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel>; -using CLGEMMLowpMatrixAReduction = test::CLSynthetizeFunction<CLGEMMLowpMatrixAReductionKernel>; +using ClGemmLowpMatrixMultiplyReshapedOnlyRhs = test::CLSynthetizeOperator<ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel>; +using ClGemmLowpMatrixAReduction = test::CLSynthetizeOperator<ClGemmLowpMatrixAReductionKernel>; class CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFusedOutputStageFixedpointExample : public Example { @@ -289,7 +290,7 @@ public: const TensorInfo info_vector_sum_row(compute_reductionB_shape(*lhs.info()), 1, DataType::S32); vector_sum_row.allocator()->init(info_vector_sum_row); - mtx_a_reduction = std::make_unique<CLGEMMLowpMatrixAReduction>(); + mtx_a_reduction = std::make_unique<ClGemmLowpMatrixAReduction>(); if(!mtx_a_reduction->validate(lhs.info(), vector_sum_row.info(), GEMMLowpReductionKernelInfo{})) { @@ -297,7 +298,7 @@ public: return false; } - mtx_a_reduction->configure(&lhs, &vector_sum_row, GEMMLowpReductionKernelInfo{}); + mtx_a_reduction->configure(lhs.info(), vector_sum_row.info(), GEMMLowpReductionKernelInfo{}); } // Initialize matrix B reduction kernel only if _a_offset is not equal to 0 if(gemm_info.a_offset != 0) @@ -311,12 +312,14 @@ public: if(!gemm.validate(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, gemm_info.a_offset == 0 ? nullptr : vector_sum_col.info(), gemm_info.b_offset == 0 ? nullptr : vector_sum_row.info(), bias.info(), dst_multipliers.info(), dst_shifts.info())) { - std::cerr << "Invalid arguments for CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel." << std::endl; + std::cerr << "Invalid arguments for ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel." << std::endl; return false; } // Configure function - gemm.configure(&lhs, &rhs_reshaped, &dst, gemm_info, gemm_info.a_offset == 0 ? nullptr : &vector_sum_col, gemm_info.b_offset == 0 ? nullptr : &vector_sum_row, &bias, &dst_multipliers, &dst_shifts); + gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, + gemm_info.a_offset == 0 ? nullptr : vector_sum_col.info(), gemm_info.b_offset == 0 ? nullptr : vector_sum_row.info(), + bias.info(), dst_multipliers.info(), dst_shifts.info()); // Allocate tensors lhs.allocator()->allocate(); @@ -335,9 +338,12 @@ public: { if(mtx_a_reduction != nullptr) { - mtx_a_reduction->run(); + ITensorPack red_pack({ { ACL_SRC, &lhs }, { ACL_DST, &dst } }); + mtx_a_reduction->run(red_pack); } - gemm.run(); + + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_BIAS, &bias }, { ACL_VEC_COL_SUM, &vector_sum_col }, { ACL_VEC_ROW_SUM, &vector_sum_row }, { ACL_SHIFTS, &dst_shifts }, { ACL_MULTIPLIERS, &dst_multipliers }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); // Make sure all the OpenCL jobs are done executing: CLScheduler::get().sync(); @@ -358,8 +364,8 @@ private: CLTensor dst_multipliers{}; CLTensor dst_shifts{}; CLTuner tuner{}; - CLGEMMLowpMatrixMultiplyReshapedOnlyRHS gemm{}; - std::unique_ptr<CLGEMMLowpMatrixAReduction> mtx_a_reduction{ nullptr }; + ClGemmLowpMatrixMultiplyReshapedOnlyRhs gemm{}; + std::unique_ptr<ClGemmLowpMatrixAReduction> mtx_a_reduction{ nullptr }; }; /** Main test program for gemmlowp reshaped rhs only with fused output stage fixedpoint |