diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 54 |
1 files changed, 53 insertions, 1 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index ef2ca48e04..ffdac43886 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Serializer.hpp" @@ -97,6 +97,25 @@ serializer::ArgMinMaxFunction GetFlatBufferArgMinMaxFunction(armnn::ArgMinMaxFun } } +serializer::ScatterNdFunction GetFlatBufferScatterNdFunction(armnn::ScatterNdFunction function) +{ + switch (function) + { + case armnn::ScatterNdFunction::Update: + return serializer::ScatterNdFunction::ScatterNdFunction_Update; + case armnn::ScatterNdFunction::Add: + return serializer::ScatterNdFunction::ScatterNdFunction_Add; + case armnn::ScatterNdFunction::Sub: + return serializer::ScatterNdFunction::ScatterNdFunction_Sub; + case armnn::ScatterNdFunction::Max: + return serializer::ScatterNdFunction::ScatterNdFunction_Max; + case armnn::ScatterNdFunction::Min: + return serializer::ScatterNdFunction::ScatterNdFunction_Min; + default: + return serializer::ScatterNdFunction::ScatterNdFunction_Update; + } +} + uint32_t SerializerStrategy::GetSerializedId(LayerGuid guid) { if (m_guidMap.empty()) @@ -1347,6 +1366,32 @@ void SerializerStrategy::SerializeNormalizationLayer(const armnn::IConnectableLa CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer); } +void SerializerStrategy::SerializeScatterNdLayer(const armnn::IConnectableLayer* layer, + const armnn::ScatterNdDescriptor& descriptor, + const char* name) +{ + IgnoreUnused(name); + + // Create FlatBuffer BaseLayer + auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ScatterNd); + + auto flatBufferDesc = serializer::CreateScatterNdDescriptor( + m_flatBufferBuilder, + GetFlatBufferScatterNdFunction(descriptor.m_Function), + descriptor.m_InputEnabled, + descriptor.m_Axis, + descriptor.m_AxisEnabled); + + // Create the FlatBuffer TileLayer + auto flatBufferLayer = serializer::CreateScatterNdLayer( + m_flatBufferBuilder, + flatBufferBaseLayer, + flatBufferDesc); + + // Add the AnyLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ScatterNdLayer); +} + void SerializerStrategy::SerializeShapeLayer(const armnn::IConnectableLayer* layer, const char* name) { @@ -2379,6 +2424,13 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer, SerializeReverseV2Layer(layer, name); break; } + case armnn::LayerType::ScatterNd: + { + const armnn::ScatterNdDescriptor& layerDescriptor = + static_cast<const armnn::ScatterNdDescriptor&>(descriptor); + SerializeScatterNdLayer(layer, layerDescriptor, name); + break; + } case armnn::LayerType::Shape: { SerializeShapeLayer(layer, name); |