aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/Serializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r--src/armnnSerializer/Serializer.cpp54
1 files changed, 53 insertions, 1 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index ef2ca48e04..ffdac43886 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Serializer.hpp"
@@ -97,6 +97,25 @@ serializer::ArgMinMaxFunction GetFlatBufferArgMinMaxFunction(armnn::ArgMinMaxFun
}
}
+serializer::ScatterNdFunction GetFlatBufferScatterNdFunction(armnn::ScatterNdFunction function)
+{
+ switch (function)
+ {
+ case armnn::ScatterNdFunction::Update:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Update;
+ case armnn::ScatterNdFunction::Add:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Add;
+ case armnn::ScatterNdFunction::Sub:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Sub;
+ case armnn::ScatterNdFunction::Max:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Max;
+ case armnn::ScatterNdFunction::Min:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Min;
+ default:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Update;
+ }
+}
+
uint32_t SerializerStrategy::GetSerializedId(LayerGuid guid)
{
if (m_guidMap.empty())
@@ -1347,6 +1366,32 @@ void SerializerStrategy::SerializeNormalizationLayer(const armnn::IConnectableLa
CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
}
+void SerializerStrategy::SerializeScatterNdLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ScatterNdDescriptor& descriptor,
+ const char* name)
+{
+ IgnoreUnused(name);
+
+ // Create FlatBuffer BaseLayer
+ auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ScatterNd);
+
+ auto flatBufferDesc = serializer::CreateScatterNdDescriptor(
+ m_flatBufferBuilder,
+ GetFlatBufferScatterNdFunction(descriptor.m_Function),
+ descriptor.m_InputEnabled,
+ descriptor.m_Axis,
+ descriptor.m_AxisEnabled);
+
+ // Create the FlatBuffer TileLayer
+ auto flatBufferLayer = serializer::CreateScatterNdLayer(
+ m_flatBufferBuilder,
+ flatBufferBaseLayer,
+ flatBufferDesc);
+
+ // Add the AnyLayer to the FlatBufferLayers
+ CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ScatterNdLayer);
+}
+
void SerializerStrategy::SerializeShapeLayer(const armnn::IConnectableLayer* layer,
const char* name)
{
@@ -2379,6 +2424,13 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
SerializeReverseV2Layer(layer, name);
break;
}
+ case armnn::LayerType::ScatterNd:
+ {
+ const armnn::ScatterNdDescriptor& layerDescriptor =
+ static_cast<const armnn::ScatterNdDescriptor&>(descriptor);
+ SerializeScatterNdLayer(layer, layerDescriptor, name);
+ break;
+ }
case armnn::LayerType::Shape:
{
SerializeShapeLayer(layer, name);