diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp | 75 |
1 files changed, 25 insertions, 50 deletions
diff --git a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp index e230e8f2e6..94d4c33fa2 100644 --- a/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpOutputStage.cpp @@ -23,20 +23,32 @@ */ #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h" +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/KernelDescriptors.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" -#include "arm_compute/runtime/CL/CLScheduler.h" -#include "src/core/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleByFixedPointKernel.h" -#include "src/core/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleByFloatKernel.h" -#include "src/core/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleKernel.h" +#include "src/core/CL/ICLKernel.h" +#include "src/runtime/gpu/cl/operators/ClGemmLowpOutputStage.h" #include <algorithm> namespace arm_compute { +struct CLGEMMLowpOutputStage::Impl +{ + const ICLTensor *src{ nullptr }; + const ICLTensor *bias{ nullptr }; + ICLTensor *dst{ nullptr }; + std::unique_ptr<opencl::ClGemmLowpOutputStage> op{ nullptr }; + ITensorPack run_pack{}; +}; + CLGEMMLowpOutputStage::CLGEMMLowpOutputStage() - : _kernel(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr) + : _impl(std::make_unique<Impl>()) { } CLGEMMLowpOutputStage::CLGEMMLowpOutputStage(CLGEMMLowpOutputStage &&) = default; @@ -52,59 +64,22 @@ void CLGEMMLowpOutputStage::configure(const CLCompileContext &compile_context, c { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - _input = input; - _bias = bias; - _output = output; + _impl->src = input; + _impl->bias = bias; + _impl->dst = output; - switch(info.type) - { - case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: - { - auto k = std::make_unique<opencl::kernels::ClGemmLowpQuantizeDownInt32ScaleByFixedPointKernel>(); - k->configure(compile_context, input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), &info); - _kernel = std::move(k); - break; - } - case GEMMLowpOutputStageType::QUANTIZE_DOWN: - { - auto k = std::make_unique<opencl::kernels::ClGemmLowpQuantizeDownInt32ScaleKernel>(); - k->configure(compile_context, input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), &info); - _kernel = std::move(k); - break; - } - case GEMMLowpOutputStageType::QUANTIZE_DOWN_FLOAT: - { - auto k = std::make_unique<opencl::kernels::ClGemmLowpQuantizeDownInt32ScaleByFloatKernel>(); - k->configure(compile_context, input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), &info); - _kernel = std::move(k); - break; - } - default: - ARM_COMPUTE_ERROR("Unsupported GEMMLowpOutputStage type."); - } + _impl->op = std::make_unique<opencl::ClGemmLowpOutputStage>(); + _impl->op->configure(compile_context, input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), info); + _impl->run_pack = { { ACL_SRC, _impl->src }, { ACL_BIAS, _impl->bias }, { ACL_DST, _impl->dst } }; } Status CLGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info) { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16); - - switch(info.type) - { - case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT: - return opencl::kernels::ClGemmLowpQuantizeDownInt32ScaleByFixedPointKernel::validate(input, bias, output, &info); - case GEMMLowpOutputStageType::QUANTIZE_DOWN: - return opencl::kernels::ClGemmLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info); - case GEMMLowpOutputStageType::QUANTIZE_DOWN_FLOAT: - return opencl::kernels::ClGemmLowpQuantizeDownInt32ScaleByFloatKernel::validate(input, bias, output, &info); - default: - return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported GEMMLowpOutputStage type."); - } + return opencl::ClGemmLowpOutputStage::validate(input, bias, output, info); } void CLGEMMLowpOutputStage::run() { - ITensorPack pack{ { ACL_SRC, _input }, { ACL_BIAS, _bias }, { ACL_DST, _output } }; - CLScheduler::get().enqueue_op(*_kernel, pack, true); + _impl->op->run(_impl->run_pack); } } // namespace arm_compute |