aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/BatchNormalizationLayer.cpp
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2017-10-09 15:46:30 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit27c9efb922832e5e6785a492e84a46934d9a47f8 (patch)
tree031e45ce8229c4801a8f8263a258cecbb2403763 /src/graph/nodes/BatchNormalizationLayer.cpp
parent63e5041e6ea0f7fb57a1cc349f1325785fa800fa (diff)
downloadComputeLibrary-27c9efb922832e5e6785a492e84a46934d9a47f8.tar.gz
COMPMID-554 Add Nodes
- BatchNormalization - DepthConvert - Dequantization - Flatten - Quantization - Reshape Change-Id: Ie01a04b7a6cc8e2b5481cf2345268e6871580d7f Reviewed-on: http://mpd-gerrit.cambridge.arm.com/91618 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/graph/nodes/BatchNormalizationLayer.cpp')
-rw-r--r--src/graph/nodes/BatchNormalizationLayer.cpp23
1 files changed, 23 insertions, 0 deletions
diff --git a/src/graph/nodes/BatchNormalizationLayer.cpp b/src/graph/nodes/BatchNormalizationLayer.cpp
index bce19016d7..a433f39dc4 100644
--- a/src/graph/nodes/BatchNormalizationLayer.cpp
+++ b/src/graph/nodes/BatchNormalizationLayer.cpp
@@ -56,6 +56,11 @@ std::unique_ptr<arm_compute::IFunction> BatchNormalizationLayer::instantiate_nod
_gamma.set_info(TensorInfo(TensorShape(batch_norm_size), in->info()->num_channels(), in->info()->data_type(), in->info()->fixed_point_position()));
}
+ bool mean_is_loaded = _mean.tensor() != nullptr;
+ bool var_is_loaded = _var.tensor() != nullptr;
+ bool gamma_is_loaded = _gamma.tensor() != nullptr;
+ bool beta_is_loaded = _beta.tensor() != nullptr;
+
// Create node context
NodeContext node_ctx(OperationType::BatchNormalizationLayer);
node_ctx.set_target(_target_hint);
@@ -67,6 +72,24 @@ std::unique_ptr<arm_compute::IFunction> BatchNormalizationLayer::instantiate_nod
node_ctx.add_output(out);
node_ctx.add_parameter<float>("epsilon", _epsilon);
+ // Fill tensors
+ if(!mean_is_loaded)
+ {
+ _mean.allocate_and_fill_if_needed();
+ }
+ if(!var_is_loaded)
+ {
+ _var.allocate_and_fill_if_needed();
+ }
+ if(!gamma_is_loaded)
+ {
+ _gamma.allocate_and_fill_if_needed();
+ }
+ if(!beta_is_loaded)
+ {
+ _beta.allocate_and_fill_if_needed();
+ }
+
// Get function
return OperationRegistry::get().find_operation(OperationType::BatchNormalizationLayer, _target_hint)->configure(node_ctx);
} \ No newline at end of file