diff options
Diffstat (limited to 'arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h')
-rw-r--r-- | arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h b/arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h index ec03bcc952..27e21cbc7e 100644 --- a/arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h +++ b/arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2019, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,11 +22,12 @@ * SOFTWARE. */ -#ifndef ARM_COMPUTE_GRAPH_BACKENDS_FUSED_CONVOLUTION_BATCH_NORMAZLIZATION_FUNCTION_H -#define ARM_COMPUTE_GRAPH_BACKENDS_FUSED_CONVOLUTION_BATCH_NORMAZLIZATION_FUNCTION_H +#ifndef ACL_ARM_COMPUTE_GRAPH_BACKENDS_FUSEDCONVOLUTIONBATCHNORMALIZATIONFUNCTION_H +#define ACL_ARM_COMPUTE_GRAPH_BACKENDS_FUSEDCONVOLUTIONBATCHNORMALIZATIONFUNCTION_H #include "arm_compute/core/Types.h" #include "arm_compute/runtime/IFunction.h" +#include "arm_compute/runtime/IMemoryManager.h" namespace arm_compute { @@ -69,15 +70,19 @@ public: * @param[in] fused_act Activation layer information in case of a fused activation. * */ - void configure(TensorType *input, - TensorType *weights, - TensorType *bias, - TensorType *output, - const TensorType *mean, - const TensorType *var, - const TensorType *beta, - const TensorType *gamma, - float epsilon, const PadStrideInfo &conv_info, unsigned int num_groups, bool fast_math, ActivationLayerInfo const &fused_act) + void configure(TensorType *input, + TensorType *weights, + TensorType *bias, + TensorType *output, + const TensorType *mean, + const TensorType *var, + const TensorType *beta, + const TensorType *gamma, + float epsilon, + const PadStrideInfo &conv_info, + unsigned int num_groups, + bool fast_math, + ActivationLayerInfo const &fused_act) { // We don't run any validate, as we assume that the layers have been already validated const bool has_bias = (bias != nullptr); @@ -85,7 +90,7 @@ public: // We check if the layer has a bias. If yes, use it in-place. If not, we need to create one // as batch normalization might end up with a bias != 0 - if(has_bias) + if (has_bias) { _fused_batch_norm_layer.configure(weights, mean, var, nullptr, nullptr, bias, beta, gamma, epsilon); bias_to_use = bias; @@ -96,9 +101,10 @@ public: bias_to_use = &_fused_bias; } - _conv_layer.configure(input, weights, bias_to_use, output, conv_info, WeightsInfo(), Size2D(1U, 1U), fused_act, fast_math, num_groups); + _conv_layer.configure(input, weights, bias_to_use, output, conv_info, WeightsInfo(), Size2D(1U, 1U), fused_act, + fast_math, num_groups); - if(!has_bias) + if (!has_bias) { _fused_bias.allocator()->allocate(); } @@ -113,7 +119,7 @@ public: void prepare() { - if(!_is_prepared) + if (!_is_prepared) { _fused_batch_norm_layer.run(); _is_prepared = true; @@ -130,4 +136,4 @@ private: } // namespace graph } // namespace arm_compute -#endif /* ARM_COMPUTE_GRAPH_BACKENDS_FUSED_CONVOLUTION_BATCH_NORMAZLIZATION_FUNCTION_H */ +#endif // ACL_ARM_COMPUTE_GRAPH_BACKENDS_FUSEDCONVOLUTIONBATCHNORMALIZATIONFUNCTION_H |