diff options
Diffstat (limited to 'src/runtime/CL/functions/CLFuseBatchNormalization.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLFuseBatchNormalization.cpp | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/src/runtime/CL/functions/CLFuseBatchNormalization.cpp b/src/runtime/CL/functions/CLFuseBatchNormalization.cpp index 72dd27e3cc..6deecdc089 100644 --- a/src/runtime/CL/functions/CLFuseBatchNormalization.cpp +++ b/src/runtime/CL/functions/CLFuseBatchNormalization.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -41,7 +41,15 @@ void CLFuseBatchNormalization::configure(const ICLTensor *input_weights, const I const ICLTensor *input_bias, const ICLTensor *bn_beta, const ICLTensor *bn_gamma, float epsilon, FuseBatchNormalizationType fbn_type) { - _fuse_bn_kernel.configure(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type); + configure(CLKernelLibrary::get().get_compile_context(), input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type); +} + +void CLFuseBatchNormalization::configure(const CLCompileContext &compile_context, const ICLTensor *input_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var, + ICLTensor *fused_weights, ICLTensor *fused_bias, + const ICLTensor *input_bias, const ICLTensor *bn_beta, const ICLTensor *bn_gamma, + float epsilon, FuseBatchNormalizationType fbn_type) +{ + _fuse_bn_kernel.configure(compile_context, input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type); } Status CLFuseBatchNormalization::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, |