aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs10
-rw-r--r--src/armnnSerializer/Serializer.cpp8
-rw-r--r--src/armnnSerializer/Serializer.hpp3
-rw-r--r--src/armnnSerializer/SerializerSupport.md1
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp36
5 files changed, 56 insertions, 2 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 552c2cc056..9b23d8508c 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -101,7 +101,8 @@ enum LayerType : uint {
Equal = 17,
Maximum = 18,
Normalization = 19,
- Pad = 20
+ Pad = 20,
+ Rsqrt = 21
}
// Base layer table to be used as part of other layers
@@ -334,6 +335,10 @@ table PadDescriptor {
padList:[uint];
}
+table RsqrtLayer {
+ base:LayerBase;
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -355,7 +360,8 @@ union Layer {
EqualLayer,
MaximumLayer,
NormalizationLayer,
- PadLayer
+ PadLayer,
+ RsqrtLayer
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 868a36d42e..5f9ca13198 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -379,6 +379,14 @@ void SerializerVisitor::VisitReshapeLayer(const armnn::IConnectableLayer* layer,
CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
}
+void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
+{
+ auto fbRsqrtBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rsqrt);
+ auto fbRsqrtLayer = serializer::CreateRsqrtLayer(m_flatBufferBuilder, fbRsqrtBaseLayer);
+
+ CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
+}
+
// Build FlatBuffer for Softmax Layer
void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
const armnn::SoftmaxDescriptor& softmaxDescriptor,
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index ef56c25f2c..6e92a067b6 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -114,6 +114,9 @@ public:
const armnn::ReshapeDescriptor& reshapeDescriptor,
const char* name = nullptr) override;
+ void VisitRsqrtLayer(const armnn::IConnectableLayer* layer,
+ const char* name = nullptr) override;
+
void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
const armnn::SoftmaxDescriptor& softmaxDescriptor,
const char* name = nullptr) override;
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index a77e8860a2..b9bc0d4479 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -23,6 +23,7 @@ The Arm NN SDK Serializer currently supports the following layers:
* Permute
* Pooling2d
* Reshape
+* Rsqrt
* Softmax
* SpaceToBatchNd
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 110bf0c581..515689a777 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1021,4 +1021,40 @@ BOOST_AUTO_TEST_CASE(SerializeDeserializePad)
{outputTensorInfo.GetShape()});
}
+BOOST_AUTO_TEST_CASE(SerializeRsqrt)
+{
+ class VerifyRsqrtName : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+ {
+ public:
+ void VisitRsqrtLayer(const armnn::IConnectableLayer*, const char* name) override
+ {
+ BOOST_TEST(name == "rsqrt");
+ }
+ };
+
+ const armnn::TensorInfo tensorInfo({ 3, 1, 2 }, armnn::DataType::Float32);
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const rsqrtLayer = network->AddRsqrtLayer("rsqrt");
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(rsqrtLayer->GetInputSlot(0));
+ rsqrtLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+ rsqrtLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ VerifyRsqrtName nameChecker;
+ deserializedNetwork->Accept(nameChecker);
+
+ CheckDeserializedNetworkAgainstOriginal(*network,
+ *deserializedNetwork,
+ {tensorInfo.GetShape()},
+ {tensorInfo.GetShape()});
+}
+
BOOST_AUTO_TEST_SUITE_END()