aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index d40cdfa591..423706ceb3 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -170,6 +170,35 @@ void SerializerVisitor::VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer*
CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_BatchToSpaceNdLayer);
}
+void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
+ const armnn::BatchNormalizationDescriptor& batchNormDescriptor,
+ const armnn::ConstTensor& mean,
+ const armnn::ConstTensor& variance,
+ const armnn::ConstTensor& beta,
+ const armnn::ConstTensor& gamma,
+ const char* name)
+{
+ auto fbBatchNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchNormalization);
+ auto fbBatchNormalizationDescriptor = serializer::CreateBatchNormalizationDescriptor(
+ m_flatBufferBuilder,
+ batchNormDescriptor.m_Eps,
+ GetFlatBufferDataLayout(batchNormDescriptor.m_DataLayout));
+
+ auto fbMeanConstTensorInfo = CreateConstTensorInfo(mean);
+ auto fbVarianceConstTensorInfo = CreateConstTensorInfo(variance);
+ auto fbBetaConstTensorInfo = CreateConstTensorInfo(beta);
+ auto fbGammaConstTensorInfo = CreateConstTensorInfo(gamma);
+ auto fbBatchNormalizationLayer = serializer::CreateBatchNormalizationLayer(m_flatBufferBuilder,
+ fbBatchNormalizationBaseLayer,
+ fbBatchNormalizationDescriptor,
+ fbMeanConstTensorInfo,
+ fbVarianceConstTensorInfo,
+ fbBetaConstTensorInfo,
+ fbGammaConstTensorInfo);
+
+ CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
+}
+
// Build FlatBuffer for Constant Layer
void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
const armnn::ConstTensor& input,