diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2019-05-28 11:44:41 +0100 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2019-06-13 16:01:42 +0000 |
commit | 2732cca12bac29e1515cee1db5005c73893c61b4 (patch) | |
tree | 050d4c20b51b2b642be21512f9b4a900e18ce88c /src/runtime/CL | |
parent | b3a0a60d0b570c58d84324059abb5caceae2561c (diff) | |
download | ComputeLibrary-2732cca12bac29e1515cee1db5005c73893c61b4.tar.gz |
COMPMID-2244: Extend CLFuseBatchNormalization to support DepthwiseConvolution weights
Change-Id: I7d1907f35cc4899379073759be2f7cce24e51e9d
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1327
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL')
-rw-r--r-- | src/runtime/CL/functions/CLFuseBatchNormalization.cpp | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/src/runtime/CL/functions/CLFuseBatchNormalization.cpp b/src/runtime/CL/functions/CLFuseBatchNormalization.cpp index 32e46787d3..72dd27e3cc 100644 --- a/src/runtime/CL/functions/CLFuseBatchNormalization.cpp +++ b/src/runtime/CL/functions/CLFuseBatchNormalization.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,20 +36,20 @@ CLFuseBatchNormalization::CLFuseBatchNormalization() { } -void CLFuseBatchNormalization::configure(const ICLTensor *conv_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var, +void CLFuseBatchNormalization::configure(const ICLTensor *input_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var, ICLTensor *fused_weights, ICLTensor *fused_bias, - const ICLTensor *conv_bias, const ICLTensor *bn_beta, const ICLTensor *bn_gamma, - float epsilon) + const ICLTensor *input_bias, const ICLTensor *bn_beta, const ICLTensor *bn_gamma, + float epsilon, FuseBatchNormalizationType fbn_type) { - _fuse_bn_kernel.configure(conv_weights, bn_mean, bn_var, fused_weights, fused_bias, conv_bias, bn_beta, bn_gamma, epsilon); + _fuse_bn_kernel.configure(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type); } -Status CLFuseBatchNormalization::validate(const ITensorInfo *conv_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, +Status CLFuseBatchNormalization::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, const ITensorInfo *fused_weights, const ITensorInfo *fused_bias, - const ITensorInfo *conv_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma, - float epsilon) + const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma, + float epsilon, FuseBatchNormalizationType fbn_type) { - return CLFuseBatchNormalizationKernel::validate(conv_weights, bn_mean, bn_var, fused_weights, fused_bias, conv_bias, bn_beta, bn_gamma, epsilon); + return CLFuseBatchNormalizationKernel::validate(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type); } void CLFuseBatchNormalization::run() |