aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp50
-rw-r--r--src/armnnDeserializer/Deserializer.hpp3
-rw-r--r--src/armnnDeserializer/DeserializerSupport.md13
3 files changed, 62 insertions, 4 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 0b18ccd051..6779f1eb06 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -191,6 +191,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
m_ParserFunctions[Layer_ArgMinMaxLayer] = &Deserializer::ParseArgMinMax;
m_ParserFunctions[Layer_BatchToSpaceNdLayer] = &Deserializer::ParseBatchToSpaceNd;
m_ParserFunctions[Layer_BatchNormalizationLayer] = &Deserializer::ParseBatchNormalization;
+ m_ParserFunctions[Layer_ComparisonLayer] = &Deserializer::ParseComparison;
m_ParserFunctions[Layer_ConcatLayer] = &Deserializer::ParseConcat;
m_ParserFunctions[Layer_ConstantLayer] = &Deserializer::ParseConstant;
m_ParserFunctions[Layer_Convolution2dLayer] = &Deserializer::ParseConvolution2d;
@@ -255,6 +256,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_BatchToSpaceNdLayer()->base();
case Layer::Layer_BatchNormalizationLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer()->base();
+ case Layer::Layer_ComparisonLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_ComparisonLayer()->base();
case Layer::Layer_ConcatLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_ConcatLayer()->base();
case Layer::Layer_ConstantLayer:
@@ -428,6 +431,26 @@ armnn::ArgMinMaxFunction ToArgMinMaxFunction(armnnSerializer::ArgMinMaxFunction
}
}
+armnn::ComparisonOperation ToComparisonOperation(armnnSerializer::ComparisonOperation operation)
+{
+ switch (operation)
+ {
+ case armnnSerializer::ComparisonOperation::ComparisonOperation_Equal:
+ return armnn::ComparisonOperation::Equal;
+ case armnnSerializer::ComparisonOperation::ComparisonOperation_Greater:
+ return armnn::ComparisonOperation::Greater;
+ case armnnSerializer::ComparisonOperation::ComparisonOperation_GreaterOrEqual:
+ return armnn::ComparisonOperation::GreaterOrEqual;
+ case armnnSerializer::ComparisonOperation::ComparisonOperation_Less:
+ return armnn::ComparisonOperation::Less;
+ case armnnSerializer::ComparisonOperation::ComparisonOperation_LessOrEqual:
+ return armnn::ComparisonOperation::LessOrEqual;
+ case armnnSerializer::ComparisonOperation::ComparisonOperation_NotEqual:
+ default:
+ return armnn::ComparisonOperation::NotEqual;
+ }
+}
+
armnn::ResizeMethod ToResizeMethod(armnnSerializer::ResizeMethod method)
{
switch (method)
@@ -1436,6 +1459,33 @@ const armnnSerializer::OriginsDescriptor* GetOriginsDescriptor(const armnnSerial
}
}
+void Deserializer::ParseComparison(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+ CHECK_LOCATION();
+
+ auto inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_ComparisonLayer();
+ auto fbDescriptor = fbLayer->descriptor();
+
+ armnn::ComparisonDescriptor descriptor;
+ descriptor.m_Operation = ToComparisonOperation(fbDescriptor->operation());
+
+ const std::string& layerName = GetLayerName(graph, layerIndex);
+ IConnectableLayer* layer = m_Network->AddComparisonLayer(descriptor, layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex)
{
CHECK_LAYERS(graph, 0, layerIndex);
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index dda588d5e3..b951483926 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -83,6 +83,7 @@ private:
void ParseArgMinMax(GraphPtr graph, unsigned int layerIndex);
void ParseBatchToSpaceNd(GraphPtr graph, unsigned int layerIndex);
void ParseBatchNormalization(GraphPtr graph, unsigned int layerIndex);
+ void ParseComparison(GraphPtr graph, unsigned int layerIndex);
void ParseConcat(GraphPtr graph, unsigned int layerIndex);
void ParseConstant(GraphPtr graph, unsigned int layerIndex);
void ParseConvolution2d(GraphPtr graph, unsigned int layerIndex);
@@ -166,4 +167,4 @@ private:
std::unordered_map<unsigned int, Connections> m_GraphConnections;
};
-} //namespace armnnDeserializer
+} // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 3b35e5c9f4..fce7064337 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -13,6 +13,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* BatchToSpaceNd
* BatchNormalization
* Concat
+* Comparison
* Constant
* Convolution2d
* DepthToSpace
@@ -20,11 +21,9 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Dequantize
* DetectionPostProcess
* Division
-* Equal
* Floor
* FullyConnected
* Gather
-* Greater
* Input
* InstanceNormalization
* L2Normalization
@@ -44,7 +43,6 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Quantize
* QuantizedLstm
* Reshape
-* ResizeBilinear
* Rsqrt
* Slice
* Softmax
@@ -59,3 +57,12 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Resize
More machine learning layers will be supported in future releases.
+
+## Deprecated layers
+
+Some layers have been deprecated and replaced by others layers. In order to maintain backward compatibility, serializations of these deprecated layers will deserialize to the layers that have replaced them, as follows:
+
+* Equal will deserialize as Comparison
+* Merger will deserialize as Concat
+* Greater will deserialize as Comparison
+* ResizeBilinear will deserialize as Resize