diff options
Diffstat (limited to 'src/runtime/CL/functions/CLLSTMLayerQuantized.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLLSTMLayerQuantized.cpp | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp b/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp index a44dcd2e24..589523a3c3 100644 --- a/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp +++ b/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp @@ -28,11 +28,6 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "src/core/CL/kernels/CLFillBorderKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h" #include "src/core/helpers/AutoConfiguration.h" #include <memory> @@ -179,7 +174,13 @@ void CLLSTMLayerQuantized::configure(const CLCompileContext &compile_context, co quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift); _memory_group.manage(&_output_lowp); - _output_stage.configure(compile_context, &_output_highp, &_bias, &_output_lowp, output_multiplier, output_shift); + + GEMMLowpOutputStageInfo info{}; + info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + info.gemmlowp_multiplier = output_multiplier; + info.gemmlowp_shift = output_shift; + info.output_data_type = DataType::QSYMM16; + _output_stage.configure(compile_context, &_output_highp, &_bias, &_output_lowp, info); _output_highp.allocator()->allocate(); _bias.allocator()->allocate(); @@ -386,7 +387,12 @@ Status CLLSTMLayerQuantized::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift)); // _output_stage - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(&output_highp, &bias_concatenated, &output_lowp)); + GEMMLowpOutputStageInfo info{}; + info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + info.gemmlowp_multiplier = output_multiplier; + info.gemmlowp_shift = output_shift; + info.output_data_type = DataType::QSYMM16; + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&output_highp, &bias_concatenated, &output_lowp, info)); TensorInfo input_gate_input; TensorInfo forget_gate_input; |