aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/BatchNormalizationLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/nodes/BatchNormalizationLayer.cpp')
-rw-r--r--src/graph/nodes/BatchNormalizationLayer.cpp23
1 files changed, 23 insertions, 0 deletions
diff --git a/src/graph/nodes/BatchNormalizationLayer.cpp b/src/graph/nodes/BatchNormalizationLayer.cpp
index bce19016d7..a433f39dc4 100644
--- a/src/graph/nodes/BatchNormalizationLayer.cpp
+++ b/src/graph/nodes/BatchNormalizationLayer.cpp
@@ -56,6 +56,11 @@ std::unique_ptr<arm_compute::IFunction> BatchNormalizationLayer::instantiate_nod
_gamma.set_info(TensorInfo(TensorShape(batch_norm_size), in->info()->num_channels(), in->info()->data_type(), in->info()->fixed_point_position()));
}
+ bool mean_is_loaded = _mean.tensor() != nullptr;
+ bool var_is_loaded = _var.tensor() != nullptr;
+ bool gamma_is_loaded = _gamma.tensor() != nullptr;
+ bool beta_is_loaded = _beta.tensor() != nullptr;
+
// Create node context
NodeContext node_ctx(OperationType::BatchNormalizationLayer);
node_ctx.set_target(_target_hint);
@@ -67,6 +72,24 @@ std::unique_ptr<arm_compute::IFunction> BatchNormalizationLayer::instantiate_nod
node_ctx.add_output(out);
node_ctx.add_parameter<float>("epsilon", _epsilon);
+ // Fill tensors
+ if(!mean_is_loaded)
+ {
+ _mean.allocate_and_fill_if_needed();
+ }
+ if(!var_is_loaded)
+ {
+ _var.allocate_and_fill_if_needed();
+ }
+ if(!gamma_is_loaded)
+ {
+ _gamma.allocate_and_fill_if_needed();
+ }
+ if(!beta_is_loaded)
+ {
+ _beta.allocate_and_fill_if_needed();
+ }
+
// Get function
return OperationRegistry::get().find_operation(OperationType::BatchNormalizationLayer, _target_hint)->configure(node_ctx);
} \ No newline at end of file