ArmNN  NotReleased
BatchNormalizationLayer.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
8 
9 namespace armnn
10 {
11 
12 class ScopedCpuTensorHandle;
13 
15 class BatchNormalizationLayer : public LayerWithParameters<BatchNormalizationDescriptor>
16 {
17 public:
19  std::unique_ptr<ScopedCpuTensorHandle> m_Mean;
21  std::unique_ptr<ScopedCpuTensorHandle> m_Variance;
23  std::unique_ptr<ScopedCpuTensorHandle> m_Beta;
25  std::unique_ptr<ScopedCpuTensorHandle> m_Gamma;
26 
31  virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override;
32 
35  BatchNormalizationLayer* Clone(Graph& graph) const override;
36 
39  void ValidateTensorShapesFromInputs() override;
40 
41  void Accept(ILayerVisitor& visitor) const override;
42 
43 protected:
47  BatchNormalizationLayer(const BatchNormalizationDescriptor& param, const char* name);
48 
50  ~BatchNormalizationLayer() = default;
51 
55 };
56 
57 } // namespace
This layer represents a batch normalization operation.
std::unique_ptr< ScopedCpuTensorHandle > m_Gamma
A unique pointer to store Gamma values.
A BatchNormalizationDescriptor for the BatchNormalizationLayer.
std::vector< std::reference_wrapper< std::unique_ptr< ScopedCpuTensorHandle > >> ConstantTensors
Definition: Layer.hpp:356
std::unique_ptr< ScopedCpuTensorHandle > m_Variance
A unique pointer to store Variance values.
~BatchNormalizationLayer()=default
Default destructor.
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
BatchNormalizationLayer * Clone(Graph &graph) const override
void Accept(ILayerVisitor &visitor) const override
std::unique_ptr< ScopedCpuTensorHandle > m_Mean
A unique pointer to store Mean values.
std::unique_ptr< ScopedCpuTensorHandle > m_Beta
A unique pointer to store Beta values.
BatchNormalizationLayer(const BatchNormalizationDescriptor &param, const char *name)
ConstantTensors GetConstantTensorsByRef() override