diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 23 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 15 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.hpp | 4 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerUtils.cpp | 16 | ||||
-rw-r--r-- | src/armnnSerializer/SerializerUtils.hpp | 3 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 68 |
6 files changed, 126 insertions, 3 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index e1b6e1f768..1f71ce19f2 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -159,7 +159,8 @@ enum LayerType : uint { Transpose = 55, QLstm = 56, Fill = 57, - Rank = 58 + Rank = 58, + LogicalBinary = 59 } // Base layer table to be used as part of other layers @@ -270,7 +271,8 @@ enum UnaryOperation : byte { Rsqrt = 1, Sqrt = 2, Exp = 3, - Neg = 4 + Neg = 4, + LogicalNot = 5 } table ElementwiseUnaryDescriptor { @@ -362,6 +364,20 @@ table L2NormalizationDescriptor { eps:float = 1e-12; } +enum LogicalBinaryOperation : byte { + LogicalAnd = 0, + LogicalOr = 1 +} + +table LogicalBinaryDescriptor { + operation:LogicalBinaryOperation; +} + +table LogicalBinaryLayer { + base:LayerBase; + descriptor:LogicalBinaryDescriptor; +} + table MinimumLayer { base:LayerBase; } @@ -924,7 +940,8 @@ union Layer { TransposeLayer, QLstmLayer, FillLayer, - RankLayer + RankLayer, + LogicalBinaryLayer } table AnyLayer { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index f85aae1b6a..379cce2109 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -564,6 +564,21 @@ void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer); } +void SerializerVisitor::VisitLogicalBinaryLayer(const armnn::IConnectableLayer* layer, + const armnn::LogicalBinaryDescriptor& descriptor, + const char* name) +{ + IgnoreUnused(name); + + auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_LogicalBinary); + auto fbDescriptor = serializer::CreateLogicalBinaryDescriptor( + m_flatBufferBuilder, + GetFlatBufferLogicalBinaryOperation(descriptor.m_Operation)); + + auto fbLayer = serializer::CreateLogicalBinaryLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor); + CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_LogicalBinaryLayer); +} + void SerializerVisitor::VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer, const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor, const char* name) diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp index babecdc056..fa3447de21 100644 --- a/src/armnnSerializer/Serializer.hpp +++ b/src/armnnSerializer/Serializer.hpp @@ -158,6 +158,10 @@ public: const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor, const char* name = nullptr) override; + void VisitLogicalBinaryLayer(const armnn::IConnectableLayer* layer, + const armnn::LogicalBinaryDescriptor& descriptor, + const char* name = nullptr) override; + void VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer, const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor, const char* name = nullptr) override; diff --git a/src/armnnSerializer/SerializerUtils.cpp b/src/armnnSerializer/SerializerUtils.cpp index 5566abf7e4..045d6aac5c 100644 --- a/src/armnnSerializer/SerializerUtils.cpp +++ b/src/armnnSerializer/SerializerUtils.cpp @@ -28,6 +28,20 @@ armnnSerializer::ComparisonOperation GetFlatBufferComparisonOperation(armnn::Com } } +armnnSerializer::LogicalBinaryOperation GetFlatBufferLogicalBinaryOperation( + armnn::LogicalBinaryOperation logicalBinaryOperation) +{ + switch (logicalBinaryOperation) + { + case armnn::LogicalBinaryOperation::LogicalAnd: + return armnnSerializer::LogicalBinaryOperation::LogicalBinaryOperation_LogicalAnd; + case armnn::LogicalBinaryOperation::LogicalOr: + return armnnSerializer::LogicalBinaryOperation::LogicalBinaryOperation_LogicalOr; + default: + throw armnn::InvalidArgumentException("Logical Binary operation unknown"); + } +} + armnnSerializer::ConstTensorData GetFlatBufferConstTensorData(armnn::DataType dataType) { switch (dataType) @@ -98,6 +112,8 @@ armnnSerializer::UnaryOperation GetFlatBufferUnaryOperation(armnn::UnaryOperatio return armnnSerializer::UnaryOperation::UnaryOperation_Exp; case armnn::UnaryOperation::Neg: return armnnSerializer::UnaryOperation::UnaryOperation_Neg; + case armnn::UnaryOperation::LogicalNot: + return armnnSerializer::UnaryOperation::UnaryOperation_LogicalNot; default: throw armnn::InvalidArgumentException("Unary operation unknown"); } diff --git a/src/armnnSerializer/SerializerUtils.hpp b/src/armnnSerializer/SerializerUtils.hpp index edd48a5e25..a3cf5ba3c1 100644 --- a/src/armnnSerializer/SerializerUtils.hpp +++ b/src/armnnSerializer/SerializerUtils.hpp @@ -35,4 +35,7 @@ armnnSerializer::NormalizationAlgorithmMethod GetFlatBufferNormalizationAlgorith armnnSerializer::ResizeMethod GetFlatBufferResizeMethod(armnn::ResizeMethod method); +armnnSerializer::LogicalBinaryOperation GetFlatBufferLogicalBinaryOperation( + armnn::LogicalBinaryOperation logicalBinaryOperation); + } // namespace armnnSerializer diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index e00fb4dcde..6866391e0f 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -1626,6 +1626,74 @@ BOOST_AUTO_TEST_CASE(EnsureL2NormalizationBackwardCompatibility) deserializedNetwork->Accept(verifier); } +BOOST_AUTO_TEST_CASE(SerializeLogicalBinary) +{ + DECLARE_LAYER_VERIFIER_CLASS_WITH_DESCRIPTOR(LogicalBinary) + + const std::string layerName("logicalBinaryAnd"); + + const armnn::TensorShape shape{2, 1, 2, 2}; + + const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); + const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); + + armnn::LogicalBinaryDescriptor descriptor(armnn::LogicalBinaryOperation::LogicalAnd); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1); + armnn::IConnectableLayer* const logicalBinaryLayer = network->AddLogicalBinaryLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer0->GetOutputSlot(0).Connect(logicalBinaryLayer->GetInputSlot(0)); + inputLayer1->GetOutputSlot(0).Connect(logicalBinaryLayer->GetInputSlot(1)); + logicalBinaryLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo); + inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo); + logicalBinaryLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + BOOST_CHECK(deserializedNetwork); + + LogicalBinaryLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor); + deserializedNetwork->Accept(verifier); +} + +BOOST_AUTO_TEST_CASE(SerializeLogicalUnary) +{ + DECLARE_LAYER_VERIFIER_CLASS_WITH_DESCRIPTOR(ElementwiseUnary) + + const std::string layerName("elementwiseUnaryLogicalNot"); + + const armnn::TensorShape shape{2, 1, 2, 2}; + + const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); + const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); + + armnn::ElementwiseUnaryDescriptor descriptor(armnn::UnaryOperation::LogicalNot); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0); + armnn::IConnectableLayer* const elementwiseUnaryLayer = + network->AddElementwiseUnaryLayer(descriptor, layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(elementwiseUnaryLayer->GetInputSlot(0)); + elementwiseUnaryLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); + elementwiseUnaryLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + + BOOST_CHECK(deserializedNetwork); + + ElementwiseUnaryLayerVerifier verifier(layerName, { inputInfo }, { outputInfo }, descriptor); + + deserializedNetwork->Accept(verifier); +} + BOOST_AUTO_TEST_CASE(SerializeLogSoftmax) { DECLARE_LAYER_VERIFIER_CLASS_WITH_DESCRIPTOR(LogSoftmax) |