aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs24
-rw-r--r--src/armnnSerializer/Serializer.cpp54
-rw-r--r--src/armnnSerializer/Serializer.hpp6
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp44
4 files changed, 124 insertions, 4 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 131970e449..3a01c504a5 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -64,6 +64,14 @@ enum ResizeMethod: byte {
Bilinear = 1,
}
+enum ScatterNdFunction: byte {
+ Update = 0,
+ Add = 1,
+ Sub = 2,
+ Max = 3,
+ Min = 4
+}
+
table TensorInfo {
dimensions:[uint];
dataType:DataType;
@@ -189,6 +197,7 @@ enum LayerType : uint {
ElementwiseBinary = 69,
ReverseV2 = 70,
Tile = 71,
+ ScatterNd = 72,
}
// Base layer table to be used as part of other layers
@@ -1066,6 +1075,18 @@ table TileLayer {
descriptor:TileDescriptor;
}
+table ScatterNdDescriptor {
+ m_Function:ScatterNdFunction = Update;
+ m_InputEnabled:bool = true;
+ m_Axis:int = 0;
+ m_AxisEnabled:bool = false;
+}
+
+table ScatterNdLayer {
+ base:LayerBase;
+ descriptor:ScatterNdDescriptor;
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -1139,6 +1160,7 @@ union Layer {
ElementwiseBinaryLayer,
ReverseV2Layer,
TileLayer,
+ ScatterNdLayer,
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index ef2ca48e04..ffdac43886 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Serializer.hpp"
@@ -97,6 +97,25 @@ serializer::ArgMinMaxFunction GetFlatBufferArgMinMaxFunction(armnn::ArgMinMaxFun
}
}
+serializer::ScatterNdFunction GetFlatBufferScatterNdFunction(armnn::ScatterNdFunction function)
+{
+ switch (function)
+ {
+ case armnn::ScatterNdFunction::Update:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Update;
+ case armnn::ScatterNdFunction::Add:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Add;
+ case armnn::ScatterNdFunction::Sub:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Sub;
+ case armnn::ScatterNdFunction::Max:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Max;
+ case armnn::ScatterNdFunction::Min:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Min;
+ default:
+ return serializer::ScatterNdFunction::ScatterNdFunction_Update;
+ }
+}
+
uint32_t SerializerStrategy::GetSerializedId(LayerGuid guid)
{
if (m_guidMap.empty())
@@ -1347,6 +1366,32 @@ void SerializerStrategy::SerializeNormalizationLayer(const armnn::IConnectableLa
CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
}
+void SerializerStrategy::SerializeScatterNdLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ScatterNdDescriptor& descriptor,
+ const char* name)
+{
+ IgnoreUnused(name);
+
+ // Create FlatBuffer BaseLayer
+ auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ScatterNd);
+
+ auto flatBufferDesc = serializer::CreateScatterNdDescriptor(
+ m_flatBufferBuilder,
+ GetFlatBufferScatterNdFunction(descriptor.m_Function),
+ descriptor.m_InputEnabled,
+ descriptor.m_Axis,
+ descriptor.m_AxisEnabled);
+
+ // Create the FlatBuffer TileLayer
+ auto flatBufferLayer = serializer::CreateScatterNdLayer(
+ m_flatBufferBuilder,
+ flatBufferBaseLayer,
+ flatBufferDesc);
+
+ // Add the AnyLayer to the FlatBufferLayers
+ CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ScatterNdLayer);
+}
+
void SerializerStrategy::SerializeShapeLayer(const armnn::IConnectableLayer* layer,
const char* name)
{
@@ -2379,6 +2424,13 @@ void SerializerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
SerializeReverseV2Layer(layer, name);
break;
}
+ case armnn::LayerType::ScatterNd:
+ {
+ const armnn::ScatterNdDescriptor& layerDescriptor =
+ static_cast<const armnn::ScatterNdDescriptor&>(descriptor);
+ SerializeScatterNdLayer(layer, layerDescriptor, name);
+ break;
+ }
case armnn::LayerType::Shape:
{
SerializeShapeLayer(layer, name);
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index afff66e21a..7434d63dd6 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2019-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -312,6 +312,10 @@ private:
const armnn::NormalizationDescriptor& normalizationDescriptor,
const char* name = nullptr);
+ void SerializeScatterNdLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ScatterNdDescriptor& descriptor,
+ const char* name);
+
void SerializeShapeLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr);
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index bfe3fc6467..37acb0c1a5 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2020-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2020-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -3065,4 +3065,46 @@ TEST_CASE("SerializeTile")
deserializedNetwork->ExecuteStrategy(verifier);
}
+TEST_CASE("SerializeScatterNd")
+{
+ const std::string layerName("ScatterNd");
+ const armnn::TensorInfo inputInfo ({ 5 }, armnn::DataType::Float32);
+ const armnn::TensorInfo outputInfo ({ 5 }, armnn::DataType::Float32);
+ const armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Float32, 0.0f, 0, true);
+ const armnn::TensorInfo updatesInfo ({ 3 }, armnn::DataType::Float32,0.0f, 0, true);
+ std::vector<float> indicesData = { 0, 2, 3 };
+ const armnn::ConstTensor indices(indicesInfo, indicesData);
+
+ std::vector<float> updatesData = { 4, 5, 6 };
+ const armnn::ConstTensor updates(updatesInfo, updatesData);
+
+ armnn::ScatterNdDescriptor desc;
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const indicesLayer = network->AddConstantLayer(indices, "Indices");
+ armnn::IConnectableLayer* const updatesLayer = network->AddConstantLayer(updates, "Updates");
+ armnn::IConnectableLayer* const scatterNdLayer = network->AddScatterNdLayer(desc, layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(scatterNdLayer->GetInputSlot(0));
+ indicesLayer->GetOutputSlot(0).Connect(scatterNdLayer->GetInputSlot(1));
+ updatesLayer->GetOutputSlot(0).Connect(scatterNdLayer->GetInputSlot(2));
+ scatterNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
+ indicesLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo);
+ updatesLayer->GetOutputSlot(0).SetTensorInfo(updatesInfo);
+ scatterNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ CHECK(deserializedNetwork);
+
+ LayerVerifierBaseWithDescriptor<armnn::ScatterNdDescriptor> verifier(layerName,
+ {inputInfo, indicesInfo, updatesInfo},
+ {outputInfo},
+ desc);
+ deserializedNetwork->ExecuteStrategy(verifier);
+}
+
}