diff options
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmConvolution.cpp')
-rw-r--r-- | src/runtime/cpu/operators/CpuGemmConvolution.cpp | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmConvolution.cpp b/src/runtime/cpu/operators/CpuGemmConvolution.cpp index fcdf8aa8f6..7defc13b20 100644 --- a/src/runtime/cpu/operators/CpuGemmConvolution.cpp +++ b/src/runtime/cpu/operators/CpuGemmConvolution.cpp @@ -58,15 +58,16 @@ CpuGemmConvolution::CpuGemmConvolution() } CpuGemmConvolution::~CpuGemmConvolution() = default; -void CpuGemmConvolution::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info, int gemm_3d_depth) +void CpuGemmConvolution::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info, + bool enable_fast_math, int gemm_3d_depth) { ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights); - ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, gemm_3d_depth, _skip_im2col)); + ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col)); // Create GEMMInfo structure const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, GEMMLowpOutputStageInfo(), false, false, false, act_info); + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info); // Supported activations in GEMM const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU, @@ -115,7 +116,7 @@ void CpuGemmConvolution::configure_mm(const ITensorInfo *src, const ITensorInfo quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info); _mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>(); - _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, false, false, act_info)); + _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info)); auto mm_mem_req = _mm_gemmlowp->workspace(); for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont) @@ -137,7 +138,7 @@ void CpuGemmConvolution::configure_mm(const ITensorInfo *src, const ITensorInfo } Status CpuGemmConvolution::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, - const ActivationLayerInfo &act_info, int gemm_3d_depth, bool skip_im2col) + const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col) { const DataType data_type = src->data_type(); const bool is_quantized = is_data_type_quantized_asymmetric(data_type); @@ -146,7 +147,7 @@ Status CpuGemmConvolution::validate_mm(const ITensorInfo *src, const ITensorInfo // Create GEMMInfo structure const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, GEMMLowpOutputStageInfo(), false, false, false, act_info); + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info); if(is_quantized) { @@ -186,8 +187,8 @@ Status CpuGemmConvolution::validate_mm(const ITensorInfo *src, const ITensorInfo std::unique_ptr<ITensorInfo> weights_qa = weights->clone(); input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset)); weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset)); - return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info, false, false, false, - act_info)); + return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info, + false, enable_fast_math, false, act_info)); } else { @@ -211,7 +212,7 @@ Status CpuGemmConvolution::validate_gemm3d(const ITensorInfo *input_info, const } void CpuGemmConvolution::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const PadStrideInfo &conv_info, const WeightsInfo &weights_info, - const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups) + const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups) { ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); ARM_COMPUTE_UNUSED(num_groups, weights_info); @@ -223,6 +224,7 @@ void CpuGemmConvolution::configure(const ITensorInfo *src, const ITensorInfo *we weights_info, dilation, act_info, + enable_fast_math, num_groups)); const DataType data_type = src->data_type(); @@ -324,7 +326,7 @@ void CpuGemmConvolution::configure(const ITensorInfo *src, const ITensorInfo *we // Configure GEMM // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0; - configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, gemm_3d_depth); + configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth); if(!_skip_col2im && _data_layout == DataLayout::NCHW) { @@ -346,7 +348,7 @@ void CpuGemmConvolution::configure(const ITensorInfo *src, const ITensorInfo *we } Status CpuGemmConvolution::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info, - const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups) + const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst); ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!"); @@ -470,7 +472,7 @@ Status CpuGemmConvolution::validate(const ITensorInfo *src, const ITensorInfo *w } info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout()); gemm_output_to_use = &info_gemm; - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, skip_col2im ? conv_h : 0, skip_im2col)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col)); // Validate Col2Im/ReshapeLayer if(!skip_col2im && (data_layout == DataLayout::NCHW)) |