aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMMDeconvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMDeconvolutionLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMMDeconvolutionLayer.cpp205
1 files changed, 124 insertions, 81 deletions
diff --git a/src/runtime/CL/functions/CLGEMMDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMDeconvolutionLayer.cpp
index d5d1b5f41e..7d40cf1829 100644
--- a/src/runtime/CL/functions/CLGEMMDeconvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMDeconvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,19 +24,15 @@
#include "arm_compute/runtime/CL/functions/CLGEMMDeconvolutionLayer.h"
#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/function_info/ActivationLayerInfo.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
+
+#include "src/common/utils/Log.h"
#include "src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.h"
#include "src/core/CL/kernels/CLFillBorderKernel.h"
-#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h"
-#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h"
-#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h"
-#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h"
-#include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h"
-#include "src/core/CL/kernels/CLIm2ColKernel.h"
-#include "src/core/CL/kernels/CLWeightsReshapeKernel.h"
#include <tuple>
@@ -44,12 +40,13 @@ namespace arm_compute
{
namespace
{
-std::pair<Coordinates, Coordinates> compute_start_end_slice_coordinates(const ITensorInfo &output_info, const PadStrideInfo &deconv_info, bool is_nchw)
+std::pair<Coordinates, Coordinates>
+compute_start_end_slice_coordinates(const ITensorInfo &output_info, const PadStrideInfo &deconv_info, bool is_nchw)
{
Coordinates start;
Coordinates end;
- if(is_nchw)
+ if (is_nchw)
{
start.set(0, deconv_info.pad_left());
start.set(1, deconv_info.pad_top());
@@ -67,13 +64,16 @@ std::pair<Coordinates, Coordinates> compute_start_end_slice_coordinates(const IT
end.set(2, output_info.dimension(2) - deconv_info.pad_bottom());
}
- return { start, end };
+ return {start, end};
}
-Status construct_gemmlowp_output_stage(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, GEMMLowpOutputStageInfo &output_stage_info)
+Status construct_gemmlowp_output_stage(const ITensorInfo *input,
+ const ITensorInfo *weights,
+ const ITensorInfo *output,
+ GEMMLowpOutputStageInfo &output_stage_info)
{
const auto data_type = input->data_type();
- if(is_data_type_quantized_asymmetric(data_type))
+ if (is_data_type_quantized_asymmetric(data_type))
{
const UniformQuantizationInfo iq_info = input->quantization_info().uniform();
const UniformQuantizationInfo wq_info = weights->quantization_info().uniform();
@@ -82,7 +82,8 @@ Status construct_gemmlowp_output_stage(const ITensorInfo *input, const ITensorIn
float multiplier = iq_info.scale * wq_info.scale / oq_info.scale;
int output_multiplier(0);
int output_shift(0);
- ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
output_stage_info.gemmlowp_multiplier = output_multiplier;
@@ -126,15 +127,21 @@ CLGEMMDeconvolutionLayer::CLGEMMDeconvolutionLayer(std::shared_ptr<IMemoryManage
CLGEMMDeconvolutionLayer::~CLGEMMDeconvolutionLayer() = default;
-Status CLGEMMDeconvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *output, const PadStrideInfo &deconv_info)
+Status CLGEMMDeconvolutionLayer::validate(const ITensorInfo *input,
+ const ITensorInfo *weights,
+ const ITensorInfo *bias,
+ const ITensorInfo *output,
+ const PadStrideInfo &deconv_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::F16, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::F16, DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
DataLayout data_layout = input->data_layout();
- const bool padded_input = deconv_info.pad_bottom() > 0 || deconv_info.pad_left() > 0 || deconv_info.pad_right() > 0 || deconv_info.pad_top() > 0;
+ const bool padded_input = deconv_info.pad_bottom() > 0 || deconv_info.pad_left() > 0 ||
+ deconv_info.pad_right() > 0 || deconv_info.pad_top() > 0;
const bool is_nchw = input->data_layout() == DataLayout::NCHW;
const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
@@ -148,21 +155,31 @@ Status CLGEMMDeconvolutionLayer::validate(const ITensorInfo *input, const ITenso
TensorShape nhwc_weights_shape = weights->tensor_shape();
TensorShape nhwc_input_shape = input->tensor_shape();
- if(is_nchw)
+ if (is_nchw)
{
permute(nhwc_weights_shape, PermutationVector(2, 0, 1));
permute(nhwc_input_shape, PermutationVector(2, 0, 1));
- TensorInfo nhwc_input_info = input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(nhwc_input_shape).set_data_layout(DataLayout::NCHW);
+ TensorInfo nhwc_input_info = input->clone()
+ ->set_is_resizable(true)
+ .reset_padding()
+ .set_tensor_shape(nhwc_input_shape)
+ .set_data_layout(DataLayout::NCHW);
- TensorInfo nhwc_weights_info = weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(nhwc_weights_shape).set_data_layout(DataLayout::NCHW);
+ TensorInfo nhwc_weights_info = weights->clone()
+ ->set_is_resizable(true)
+ .reset_padding()
+ .set_tensor_shape(nhwc_weights_shape)
+ .set_data_layout(DataLayout::NCHW);
CLPermute::validate(weights, &nhwc_weights_info, PermutationVector(2, 0, 1));
CLPermute::validate(input, &nhwc_input_info, PermutationVector(2, 0, 1));
}
- const TensorShape reshaped_shape = TensorShape(nhwc_weights_shape[0], nhwc_weights_shape[1] * nhwc_weights_shape[2] * nhwc_weights_shape[3]);
- const TensorInfo reshaped_info = weights->clone()->set_tensor_shape(reshaped_shape).set_data_layout(DataLayout::NCHW).set_is_resizable(true);
+ const TensorShape reshaped_shape =
+ TensorShape(nhwc_weights_shape[0], nhwc_weights_shape[1] * nhwc_weights_shape[2] * nhwc_weights_shape[3]);
+ const TensorInfo reshaped_info =
+ weights->clone()->set_tensor_shape(reshaped_shape).set_data_layout(DataLayout::NCHW).set_is_resizable(true);
ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayer::validate(weights, &reshaped_info));
TensorShape transposed_shape(reshaped_shape[1], reshaped_shape[0]);
@@ -170,76 +187,95 @@ Status CLGEMMDeconvolutionLayer::validate(const ITensorInfo *input, const ITenso
ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(&reshaped_info, &reshaped_t_info));
TensorShape gemm_output_shape(weights->dimension(idx_w) * weights->dimension(idx_h) * weights->dimension(idx_b),
- input->dimension(idx_w),
- input->dimension(idx_h),
- input->dimension(idx_b));
+ input->dimension(idx_w), input->dimension(idx_h), input->dimension(idx_b));
TensorInfo gemm_output_info = reshaped_t_info.clone()->set_tensor_shape(gemm_output_shape).set_is_resizable(true);
GEMMInfo gemm_info(false, false, true, input->dimension(idx_h), true);
GEMMLowpOutputStageInfo output_stage_info;
- if(is_quantized)
+ if (is_quantized)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(&input->clone()->set_tensor_shape(nhwc_input_shape), &reshaped_t_info, nullptr, &gemm_output_info.set_data_type(DataType::S32),
- gemm_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(
+ &input->clone()->set_tensor_shape(nhwc_input_shape), &reshaped_t_info, nullptr,
+ &gemm_output_info.set_data_type(DataType::S32), gemm_info));
ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(input, weights, output, output_stage_info));
}
else
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input->clone()->set_tensor_shape(nhwc_input_shape).set_is_resizable(true), &reshaped_t_info, nullptr, &gemm_output_info, 1.0f, 0.0f, gemm_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ CLGEMM::validate(&input->clone()->set_tensor_shape(nhwc_input_shape).set_is_resizable(true),
+ &reshaped_t_info, nullptr, &gemm_output_info, 1.0f, 0.0f, gemm_info));
}
const PadStrideInfo stride_info(deconv_info.stride().first, deconv_info.stride().second);
- auto out_dims = deconvolution_output_dimensions(input->dimension(idx_w), input->dimension(idx_h), weights->dimension(idx_w), weights->dimension(idx_h), stride_info);
- const TensorShape deconv_shape = misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input, *weights);
- TensorInfo col2im_output_info = gemm_output_info.clone()->set_tensor_shape(deconv_shape).set_is_resizable(true);
+ auto out_dims = deconvolution_output_dimensions(input->dimension(idx_w), input->dimension(idx_h),
+ weights->dimension(idx_w), weights->dimension(idx_h), stride_info);
+ const TensorShape deconv_shape =
+ misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input, *weights);
+ TensorInfo col2im_output_info = gemm_output_info.clone()->set_tensor_shape(deconv_shape).set_is_resizable(true);
- if(padded_input && is_quantized)
+ if (padded_input && is_quantized)
{
const auto start_end = compute_start_end_slice_coordinates(col2im_output_info, deconv_info, is_nchw);
- ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionReshapeOutputKernel::validate(&gemm_output_info, bias, &col2im_output_info, input, weights, deconv_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&col2im_output_info, nullptr, &col2im_output_info.clone()->set_is_resizable(true).set_data_type(input->data_type()), output_stage_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&col2im_output_info.clone()->set_is_resizable(true).set_data_type(input->data_type()), output, start_end.first, start_end.second));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionReshapeOutputKernel::validate(
+ &gemm_output_info, bias, &col2im_output_info, input, weights, deconv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(
+ &col2im_output_info, nullptr,
+ &col2im_output_info.clone()->set_is_resizable(true).set_data_type(input->data_type()), output_stage_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ CLSlice::validate(&col2im_output_info.clone()->set_is_resizable(true).set_data_type(input->data_type()),
+ output, start_end.first, start_end.second));
}
- else if(padded_input)
+ else if (padded_input)
{
const auto start_end = compute_start_end_slice_coordinates(col2im_output_info, deconv_info, is_nchw);
- ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionReshapeOutputKernel::validate(&gemm_output_info, bias, &col2im_output_info, input, weights, deconv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionReshapeOutputKernel::validate(
+ &gemm_output_info, bias, &col2im_output_info, input, weights, deconv_info));
ARM_COMPUTE_RETURN_ON_ERROR(CLSlice::validate(&col2im_output_info, output, start_end.first, start_end.second));
}
- else if(is_quantized)
+ else if (is_quantized)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionReshapeOutputKernel::validate(&gemm_output_info, bias, &col2im_output_info, input, weights, deconv_info));
- ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&col2im_output_info, nullptr, output, output_stage_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionReshapeOutputKernel::validate(
+ &gemm_output_info, bias, &col2im_output_info, input, weights, deconv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ CLGEMMLowpOutputStage::validate(&col2im_output_info, nullptr, output, output_stage_info));
}
else
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLDeconvolutionReshapeOutputKernel::validate(&gemm_output_info, bias, output, input, weights, deconv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ CLDeconvolutionReshapeOutputKernel::validate(&gemm_output_info, bias, output, input, weights, deconv_info));
}
return Status{};
}
-void CLGEMMDeconvolutionLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const PadStrideInfo &deconv_info)
+void CLGEMMDeconvolutionLayer::configure(const ICLTensor *input,
+ const ICLTensor *weights,
+ const ICLTensor *bias,
+ ICLTensor *output,
+ const PadStrideInfo &deconv_info)
{
configure(CLKernelLibrary::get().get_compile_context(), input, weights, bias, output, deconv_info);
}
-void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
- const PadStrideInfo &deconv_info)
+void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context,
+ const ICLTensor *input,
+ const ICLTensor *weights,
+ const ICLTensor *bias,
+ ICLTensor *output,
+ const PadStrideInfo &deconv_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_ERROR_THROW_ON(CLGEMMDeconvolutionLayer::validate(input->info(),
- weights->info(),
- bias != nullptr ? bias->info() : nullptr,
- output->info(),
- deconv_info));
+ ARM_COMPUTE_ERROR_THROW_ON(CLGEMMDeconvolutionLayer::validate(
+ input->info(), weights->info(), bias != nullptr ? bias->info() : nullptr, output->info(), deconv_info));
+ ARM_COMPUTE_LOG_PARAMS(input, weights, bias, output, deconv_info);
_original_weights = weights;
- _padded_input = deconv_info.pad_bottom() > 0 || deconv_info.pad_left() > 0 || deconv_info.pad_right() > 0 || deconv_info.pad_top() > 0;
- _is_nchw = input->info()->data_layout() == DataLayout::NCHW;
- _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
+ _padded_input = deconv_info.pad_bottom() > 0 || deconv_info.pad_left() > 0 || deconv_info.pad_right() > 0 ||
+ deconv_info.pad_top() > 0;
+ _is_nchw = input->info()->data_layout() == DataLayout::NCHW;
+ _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
const ICLTensor *input_to_use = input;
const ICLTensor *weights_to_use = weights;
@@ -248,7 +284,7 @@ void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context
// do an outer product in NCHW and then an accumulation through a reduction. This would have two
// drawbacks: first, the outer product is less efficient than a full GEMM. Second, the reduction
// might be slower than GEMM.
- if(_is_nchw)
+ if (_is_nchw)
{
_memory_group.manage(&_permuted_input);
_permute_input_to_nhwc.configure(compile_context, input, &_permuted_input, PermutationVector(2U, 0U, 1U));
@@ -260,10 +296,11 @@ void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context
}
// Reshape the input weights. The weights will be reshaped only once during the call to prepare()
- _reshaped_weights.allocator()->init(TensorInfo(TensorShape(weights_to_use->info()->dimension(0),
- weights_to_use->info()->dimension(1) * weights_to_use->info()->dimension(2) * weights_to_use->info()->dimension(3)),
- 1,
- input->info()->data_type(), weights->info()->quantization_info()));
+ _reshaped_weights.allocator()->init(
+ TensorInfo(TensorShape(weights_to_use->info()->dimension(0), weights_to_use->info()->dimension(1) *
+ weights_to_use->info()->dimension(2) *
+ weights_to_use->info()->dimension(3)),
+ 1, input->info()->data_type(), weights->info()->quantization_info()));
_reshape_weights.configure(compile_context, weights_to_use, &_reshaped_weights);
_transpose_weights.configure(compile_context, &_reshaped_weights, &_reshaped_weights_t);
@@ -272,15 +309,17 @@ void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context
GEMMInfo gemm_info(false, false, true, input->info()->dimension(idx_h), true);
// Configure output stage for asymmetric quantized types
- if(_is_quantized)
+ if (_is_quantized)
{
// gemmlowp adds the offsets (instead of subtracting them). Thus, we need to negate the original
// and restore them back to make it work properly.
QuantizationInfo iq_info = input->info()->quantization_info();
QuantizationInfo wq_info = weights->info()->quantization_info();
- input_to_use->info()->set_quantization_info(QuantizationInfo(iq_info.uniform().scale, -iq_info.uniform().offset));
- _reshaped_weights_t.info()->set_quantization_info(QuantizationInfo(wq_info.uniform().scale, -wq_info.uniform().offset));
+ input_to_use->info()->set_quantization_info(
+ QuantizationInfo(iq_info.uniform().scale, -iq_info.uniform().offset));
+ _reshaped_weights_t.info()->set_quantization_info(
+ QuantizationInfo(wq_info.uniform().scale, -wq_info.uniform().offset));
_mm_gemmlowp.configure(compile_context, input_to_use, &_reshaped_weights_t, nullptr, &_gemm_output, gemm_info);
@@ -289,10 +328,11 @@ void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context
}
else
{
- _mm_gemm.configure(compile_context, input_to_use, &_reshaped_weights_t, nullptr, &_gemm_output, 1.f, 0.0f, gemm_info);
+ _mm_gemm.configure(compile_context, input_to_use, &_reshaped_weights_t, nullptr, &_gemm_output, 1.f, 0.0f,
+ gemm_info);
}
- if(_is_nchw)
+ if (_is_nchw)
{
_permuted_input.allocator()->allocate();
}
@@ -301,7 +341,7 @@ void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context
ICLTensor *slice_output = nullptr;
ICLTensor *output_stage_output = nullptr;
- if(_padded_input && _is_quantized)
+ if (_padded_input && _is_quantized)
{
_memory_group.manage(&_slice_gemm_input);
_memory_group.manage(&_gemmlowp_final);
@@ -309,13 +349,13 @@ void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context
output_stage_output = &_slice_gemm_input;
slice_output = output;
}
- else if(_padded_input)
+ else if (_padded_input)
{
_memory_group.manage(&_slice_gemm_input);
deconv_reshape_output = &_slice_gemm_input;
slice_output = output;
}
- else if(_is_quantized)
+ else if (_is_quantized)
{
_memory_group.manage(&_gemmlowp_final);
deconv_reshape_output = &_gemmlowp_final;
@@ -327,21 +367,24 @@ void CLGEMMDeconvolutionLayer::configure(const CLCompileContext &compile_context
}
// Configure a Col2Im call to reshape the output of GEMM
- _deconv_reshape->configure(compile_context, &_gemm_output, bias, deconv_reshape_output, input->info(), weights->info(), deconv_info);
+ _deconv_reshape->configure(compile_context, &_gemm_output, bias, deconv_reshape_output, input->info(),
+ weights->info(), deconv_info);
_gemm_output.allocator()->allocate();
- if(_is_quantized)
+ if (_is_quantized)
{
GEMMLowpOutputStageInfo output_stage_info;
construct_gemmlowp_output_stage(input->info(), weights->info(), output->info(), output_stage_info);
- _gemmlowp_output_stage.configure(compile_context, &_gemmlowp_final, nullptr, output_stage_output, output_stage_info);
+ _gemmlowp_output_stage.configure(compile_context, &_gemmlowp_final, nullptr, output_stage_output,
+ output_stage_info);
_gemmlowp_final.allocator()->allocate();
}
// If the input was padded, the output needs to be sliced.
- if(_padded_input)
+ if (_padded_input)
{
- const auto start_end = compute_start_end_slice_coordinates(*deconv_reshape_output->info(), deconv_info, _is_nchw);
+ const auto start_end =
+ compute_start_end_slice_coordinates(*deconv_reshape_output->info(), deconv_info, _is_nchw);
_slice_gemm.configure(compile_context, &_slice_gemm_input, slice_output, start_end.first, start_end.second);
_slice_gemm_input.allocator()->allocate();
}
@@ -353,12 +396,12 @@ void CLGEMMDeconvolutionLayer::run()
MemoryGroupResourceScope scope_mg(_memory_group);
- if(_is_nchw)
+ if (_is_nchw)
{
_permute_input_to_nhwc.run();
}
- if(_is_quantized)
+ if (_is_quantized)
{
_mm_gemmlowp.run();
}
@@ -369,12 +412,12 @@ void CLGEMMDeconvolutionLayer::run()
CLScheduler::get().enqueue(*_deconv_reshape, false);
- if(_is_quantized)
+ if (_is_quantized)
{
_gemmlowp_output_stage.run();
}
- if(_padded_input)
+ if (_padded_input)
{
_slice_gemm.run();
}
@@ -382,11 +425,11 @@ void CLGEMMDeconvolutionLayer::run()
void CLGEMMDeconvolutionLayer::prepare()
{
- if(!_is_prepared)
+ if (!_is_prepared)
{
ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
- if(_is_nchw)
+ if (_is_nchw)
{
_permuted_weights.allocator()->allocate();
_permute_weights_to_nhwc.run();
@@ -395,7 +438,7 @@ void CLGEMMDeconvolutionLayer::prepare()
_reshaped_weights.allocator()->allocate();
_reshape_weights.run();
- if(_is_nchw)
+ if (_is_nchw)
{
_permuted_weights.allocator()->free();
}
@@ -404,7 +447,7 @@ void CLGEMMDeconvolutionLayer::prepare()
_transpose_weights.run();
// Prepare gemm
- if(!_is_quantized)
+ if (!_is_quantized)
{
_mm_gemm.prepare();
}
@@ -414,7 +457,7 @@ void CLGEMMDeconvolutionLayer::prepare()
}
// Free resources
- if(!_reshaped_weights_t.is_used())
+ if (!_reshaped_weights_t.is_used())
{
_reshaped_weights_t.allocator()->free();
}