diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFuseBatchNormalization.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLFuseBatchNormalization.h | 41 |
1 files changed, 31 insertions, 10 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFuseBatchNormalization.h b/arm_compute/runtime/CL/functions/CLFuseBatchNormalization.h index cd75270392..2e777273cd 100644 --- a/arm_compute/runtime/CL/functions/CLFuseBatchNormalization.h +++ b/arm_compute/runtime/CL/functions/CLFuseBatchNormalization.h @@ -78,9 +78,16 @@ public: * @param[in] epsilon (Optional) Batch normalization layer epsilon parameter. Defaults to 0.001f. * @param[in] fbn_type (Optional) Fused batch normalization type. Defaults to Convolution. */ - void configure(const ICLTensor *input_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var, ICLTensor *fused_weights, ICLTensor *fused_bias, - const ICLTensor *input_bias = nullptr, const ICLTensor *bn_beta = nullptr, const ICLTensor *bn_gamma = nullptr, - float epsilon = 0.001f, FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION); + void configure(const ICLTensor *input_weights, + const ICLTensor *bn_mean, + const ICLTensor *bn_var, + ICLTensor *fused_weights, + ICLTensor *fused_bias, + const ICLTensor *input_bias = nullptr, + const ICLTensor *bn_beta = nullptr, + const ICLTensor *bn_gamma = nullptr, + float epsilon = 0.001f, + FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION); /** Set the input and output tensors. * * @param[in] compile_context The compile context to be used. @@ -97,9 +104,17 @@ public: * @param[in] epsilon (Optional) Batch normalization layer epsilon parameter. Defaults to 0.001f. * @param[in] fbn_type (Optional) Fused batch normalization type. Defaults to Convolution. */ - void configure(const CLCompileContext &compile_context, const ICLTensor *input_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var, ICLTensor *fused_weights, ICLTensor *fused_bias, - const ICLTensor *input_bias = nullptr, const ICLTensor *bn_beta = nullptr, const ICLTensor *bn_gamma = nullptr, - float epsilon = 0.001f, FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION); + void configure(const CLCompileContext &compile_context, + const ICLTensor *input_weights, + const ICLTensor *bn_mean, + const ICLTensor *bn_var, + ICLTensor *fused_weights, + ICLTensor *fused_bias, + const ICLTensor *input_bias = nullptr, + const ICLTensor *bn_beta = nullptr, + const ICLTensor *bn_gamma = nullptr, + float epsilon = 0.001f, + FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION); /** Static function to check if given info will lead to a valid configuration of @ref CLFuseBatchNormalization * * @param[in] input_weights Input weights tensor info for convolution or depthwise convolution layer. Data type supported: F16/F32. Data layout supported: NCHW, NHWC @@ -117,10 +132,16 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, - const ITensorInfo *fused_weights, const ITensorInfo *fused_bias, - const ITensorInfo *input_bias = nullptr, const ITensorInfo *bn_beta = nullptr, const ITensorInfo *bn_gamma = nullptr, - float epsilon = 0.001f, FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION); + static Status validate(const ITensorInfo *input_weights, + const ITensorInfo *bn_mean, + const ITensorInfo *bn_var, + const ITensorInfo *fused_weights, + const ITensorInfo *fused_bias, + const ITensorInfo *input_bias = nullptr, + const ITensorInfo *bn_beta = nullptr, + const ITensorInfo *bn_gamma = nullptr, + float epsilon = 0.001f, + FuseBatchNormalizationType fbn_type = FuseBatchNormalizationType::CONVOLUTION); // Inherited methods overridden: void run() override; |