aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLFuseBatchNormalization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLFuseBatchNormalization.cpp')
-rw-r--r--src/runtime/CL/functions/CLFuseBatchNormalization.cpp12
1 files changed, 10 insertions, 2 deletions
diff --git a/src/runtime/CL/functions/CLFuseBatchNormalization.cpp b/src/runtime/CL/functions/CLFuseBatchNormalization.cpp
index 72dd27e3cc..6deecdc089 100644
--- a/src/runtime/CL/functions/CLFuseBatchNormalization.cpp
+++ b/src/runtime/CL/functions/CLFuseBatchNormalization.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,7 +41,15 @@ void CLFuseBatchNormalization::configure(const ICLTensor *input_weights, const I
const ICLTensor *input_bias, const ICLTensor *bn_beta, const ICLTensor *bn_gamma,
float epsilon, FuseBatchNormalizationType fbn_type)
{
- _fuse_bn_kernel.configure(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type);
+ configure(CLKernelLibrary::get().get_compile_context(), input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type);
+}
+
+void CLFuseBatchNormalization::configure(const CLCompileContext &compile_context, const ICLTensor *input_weights, const ICLTensor *bn_mean, const ICLTensor *bn_var,
+ ICLTensor *fused_weights, ICLTensor *fused_bias,
+ const ICLTensor *input_bias, const ICLTensor *bn_beta, const ICLTensor *bn_gamma,
+ float epsilon, FuseBatchNormalizationType fbn_type)
+{
+ _fuse_bn_kernel.configure(compile_context, input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type);
}
Status CLFuseBatchNormalization::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,