aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruno Goncalves <bruno.slackware@gmail.com>2021-07-11 14:10:15 -0300
committerMatthew Sloyan <matthew.sloyan@arm.com>2021-07-21 09:28:23 +0000
commit2d0eb86a5756fb9402bd31d3f5adc5438305f676 (patch)
treecf454f1a82e79237646880e6a7dd76eda1167e1d
parenta34c98d17b6784e4b5b3a9d77654717e86517bab (diff)
downloadarmnn-2d0eb86a5756fb9402bd31d3f5adc5438305f676.tar.gz
Added comparison operators to TfLiteParser
E.g. Equal, NotEqual, Greater, GreaterOrEqual, Less and LessOrEqual Signed-off-by: Bruno Goncalves <bruno.slackware@gmail.com> Change-Id: Id56ef3cc19cc5c5daa19354010c9f25766e5fd00
-rw-r--r--CMakeLists.txt1
-rw-r--r--docs/01_01_parsers.dox6
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp69
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp7
-rw-r--r--src/armnnTfLiteParser/test/Comparison.cpp266
5 files changed, 349 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 212b90de7f..2e3af391af 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -662,6 +662,7 @@ if(BUILD_UNIT_TESTS)
src/armnnTfLiteParser/test/AvgPool2D.cpp
src/armnnTfLiteParser/test/BatchToSpaceND.cpp
src/armnnTfLiteParser/test/Cast.cpp
+ src/armnnTfLiteParser/test/Comparison.cpp
src/armnnTfLiteParser/test/Concatenation.cpp
src/armnnTfLiteParser/test/Constant.cpp
src/armnnTfLiteParser/test/Conv2D.cpp
diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox
index 761380c939..63c34210fe 100644
--- a/docs/01_01_parsers.dox
+++ b/docs/01_01_parsers.dox
@@ -117,11 +117,16 @@ The Arm NN SDK TensorFlow Lite parser currently supports the following operators
- DEQUANTIZE
- DIV
- ELU
+- EQUAL
- EXP
- FULLY_CONNECTED, Supported Fused Activation: RELU , RELU6 , TANH, NONE
- GATHER
+- GREATER
+- GREATER_EQUAL
- HARD_SWISH
- LEAKY_RELU
+- LESS
+- LESS_EQUAL
- LOGICAL_NOT
- LOGISTIC
- L2_NORMALIZATION
@@ -131,6 +136,7 @@ The Arm NN SDK TensorFlow Lite parser currently supports the following operators
- MINIMUM
- MUL
- NEG
+- NOT_EQUAL
- PACK
- PAD
- PRELU
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 2df47eb198..410f452ff1 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -622,12 +622,17 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_DEQUANTIZE] = &TfLiteParserImpl::ParseDequantize;
m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParserImpl::ParseDiv;
m_ParserFunctions[tflite::BuiltinOperator_ELU] = &TfLiteParserImpl::ParseElu;
+ m_ParserFunctions[tflite::BuiltinOperator_EQUAL] = &TfLiteParserImpl::ParseEqual;
m_ParserFunctions[tflite::BuiltinOperator_EXP] = &TfLiteParserImpl::ParseExp;
m_ParserFunctions[tflite::BuiltinOperator_EXPAND_DIMS] = &TfLiteParserImpl::ParseExpandDims;
m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParserImpl::ParseFullyConnected;
m_ParserFunctions[tflite::BuiltinOperator_GATHER] = &TfLiteParserImpl::ParseGather;
+ m_ParserFunctions[tflite::BuiltinOperator_GREATER] = &TfLiteParserImpl::ParseGreater;
+ m_ParserFunctions[tflite::BuiltinOperator_GREATER_EQUAL] = &TfLiteParserImpl::ParseGreaterOrEqual;
m_ParserFunctions[tflite::BuiltinOperator_HARD_SWISH] = &TfLiteParserImpl::ParseHardSwish;
m_ParserFunctions[tflite::BuiltinOperator_LEAKY_RELU] = &TfLiteParserImpl::ParseLeakyRelu;
+ m_ParserFunctions[tflite::BuiltinOperator_LESS] = &TfLiteParserImpl::ParseLess;
+ m_ParserFunctions[tflite::BuiltinOperator_LESS_EQUAL] = &TfLiteParserImpl::ParseLessOrEqual;
m_ParserFunctions[tflite::BuiltinOperator_LOGICAL_NOT] = &TfLiteParserImpl::ParseLogicalNot;
m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParserImpl::ParseLogistic;
m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION] = &TfLiteParserImpl::ParseL2Normalization;
@@ -637,6 +642,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_MINIMUM] = &TfLiteParserImpl::ParseMinimum;
m_ParserFunctions[tflite::BuiltinOperator_MUL] = &TfLiteParserImpl::ParseMul;
m_ParserFunctions[tflite::BuiltinOperator_NEG] = &TfLiteParserImpl::ParseNeg;
+ m_ParserFunctions[tflite::BuiltinOperator_NOT_EQUAL] = &TfLiteParserImpl::ParseNotEqual;
m_ParserFunctions[tflite::BuiltinOperator_PACK] = &TfLiteParserImpl::ParsePack;
m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParserImpl::ParsePad;
m_ParserFunctions[tflite::BuiltinOperator_PRELU] = &TfLiteParserImpl::ParsePrelu;
@@ -3373,6 +3379,69 @@ void TfLiteParserImpl::ParseElementwiseUnary(size_t subgraphIndex, size_t operat
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
}
+void TfLiteParserImpl::ParseEqual(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Equal);
+}
+
+void TfLiteParserImpl::ParseNotEqual(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::NotEqual);
+}
+
+void TfLiteParserImpl::ParseGreater(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Greater);
+}
+
+void TfLiteParserImpl::ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::GreaterOrEqual);
+}
+
+void TfLiteParserImpl::ParseLess(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Less);
+}
+
+void TfLiteParserImpl::ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::LessOrEqual);
+}
+
+void TfLiteParserImpl::ParseComparison(size_t subgraphIndex, size_t operatorIndex,
+ ComparisonOperation comparisonOperation)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = std::string(GetComparisonOperationAsCString(comparisonOperation)) + ":{}:{}";
+ std::string layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex);
+
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]);
+ CheckMatchingQuantization(inputTensorInfo, input1TensorInfo, layerNameFormatted, "Input 0", "Input 1");
+
+ ComparisonDescriptor desc;
+ desc.m_Operation = comparisonOperation;
+ IConnectableLayer* layer = m_Network->AddComparisonLayer(desc, layerNameFormatted.c_str());
+ ARMNN_ASSERT(layer != nullptr);
+
+ TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
armnn::IConnectableLayer* TfLiteParserImpl::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 49ccd2705c..e601540fb1 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -106,6 +106,7 @@ private:
void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
void ParseCast(size_t subgraphIndex, size_t operatorIndex);
+ void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex);
@@ -115,12 +116,17 @@ private:
void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation);
void ParseElu(size_t subgraphIndex, size_t operatorIndex);
+ void ParseEqual(size_t subgraphIndex, size_t operatorIndex);
void ParseExp(size_t subgraphIndex, size_t operatorIndex);
void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex);
void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
void ParseGather(size_t subgraphIndex, size_t operatorIndex);
+ void ParseGreater(size_t subgraphIndex, size_t operatorIndex);
+ void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex);
void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
+ void ParseLess(size_t subgraphIndex, size_t operatorIndex);
+ void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex);
void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex);
void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
@@ -130,6 +136,7 @@ private:
void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
void ParseMul(size_t subgraphIndex, size_t operatorIndex);
void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
+ void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex);
void ParsePack(size_t subgraphIndex, size_t operatorIndex);
void ParsePad(size_t subgraphIndex, size_t operatorIndex);
void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
diff --git a/src/armnnTfLiteParser/test/Comparison.cpp b/src/armnnTfLiteParser/test/Comparison.cpp
new file mode 100644
index 0000000000..5ae194be27
--- /dev/null
+++ b/src/armnnTfLiteParser/test/Comparison.cpp
@@ -0,0 +1,266 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+#include <string>
+
+TEST_SUITE("TensorflowLiteParser_Comparison")
+{
+struct ComparisonFixture : public ParserFlatbuffersFixture
+{
+ explicit ComparisonFixture(const std::string& operatorCode,
+ const std::string& dataType,
+ const std::string& inputShape,
+ const std::string& inputShape2,
+ const std::string& outputShape)
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": )" + operatorCode + R"( } ],
+ "subgraphs": [ {
+ "tensors": [
+ {
+ "shape": )" + inputShape + R"(,
+ "type": )" + dataType + R"( ,
+ "buffer": 0,
+ "name": "inputTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + inputShape2 + R"(,
+ "type": )" + dataType + R"( ,
+ "buffer": 1,
+ "name": "inputTensor2",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + outputShape + R"( ,
+ "type": "BOOL",
+ "buffer": 2,
+ "name": "outputTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ }
+ ],
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "custom_options_format": "FLEXBUFFERS"
+ }
+ ],
+ } ],
+ "buffers" : [
+ { },
+ { }
+ ]
+ }
+ )";
+ Setup();
+ }
+};
+
+struct SimpleEqualFixture : public ComparisonFixture
+{
+ SimpleEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(SimpleEqualFixture, "SimpleEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 0, 1, 5, 6 }}},
+ {{"outputTensor", { 1, 1, 0, 0 }}});
+}
+
+struct BroadcastEqualFixture : public ComparisonFixture
+{
+ BroadcastEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(BroadcastEqualFixture, "BroadcastEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 0, 1 }}},
+ {{"outputTensor", { 1, 1, 0, 0 }}});
+}
+
+struct SimpleNotEqualFixture : public ComparisonFixture
+{
+ SimpleNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(SimpleNotEqualFixture, "SimpleNotEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 0, 1, 5, 6 }}},
+ {{"outputTensor", { 0, 0, 1, 1 }}});
+}
+
+struct BroadcastNotEqualFixture : public ComparisonFixture
+{
+ BroadcastNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(BroadcastNotEqualFixture, "BroadcastNotEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 0, 1 }}},
+ {{"outputTensor", { 0, 0, 1, 1 }}});
+}
+
+struct SimpleGreaterFixture : public ComparisonFixture
+{
+ SimpleGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(SimpleGreaterFixture, "SimpleGreater")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 2, 3, 6 }},
+ {"inputTensor2", { 0, 1, 5, 3 }}},
+ {{"outputTensor", { 0, 1, 0, 1 }}});
+}
+
+struct BroadcastGreaterFixture : public ComparisonFixture
+{
+ BroadcastGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(BroadcastGreaterFixture, "BroadcastGreater")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 5, 4, 1, 0 }},
+ {"inputTensor2", { 2, 3 }}},
+ {{"outputTensor", { 1, 1, 0, 0 }}});
+}
+
+struct SimpleGreaterOrEqualFixture : public ComparisonFixture
+{
+ SimpleGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(SimpleGreaterOrEqualFixture, "SimpleGreaterOrEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 2, 3, 6 }},
+ {"inputTensor2", { 0, 1, 5, 3 }}},
+ {{"outputTensor", { 1, 1, 0, 1 }}});
+}
+
+struct BroadcastGreaterOrEqualFixture : public ComparisonFixture
+{
+ BroadcastGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8",
+ "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(BroadcastGreaterOrEqualFixture, "BroadcastGreaterOrEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 5, 4, 1, 0 }},
+ {"inputTensor2", { 2, 4 }}},
+ {{"outputTensor", { 1, 1, 0, 0 }}});
+}
+
+struct SimpleLessFixture : public ComparisonFixture
+{
+ SimpleLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(SimpleLessFixture, "SimpleLess")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 2, 3, 6 }},
+ {"inputTensor2", { 0, 1, 5, 3 }}},
+ {{"outputTensor", { 0, 0, 1, 0 }}});
+}
+
+struct BroadcastLessFixture : public ComparisonFixture
+{
+ BroadcastLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(BroadcastLessFixture, "BroadcastLess")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 5, 4, 1, 0 }},
+ {"inputTensor2", { 2, 3 }}},
+ {{"outputTensor", { 0, 0, 1, 1 }}});
+}
+
+struct SimpleLessOrEqualFixture : public ComparisonFixture
+{
+ SimpleLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(SimpleLessOrEqualFixture, "SimpleLessOrEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 0, 2, 3, 6 }},
+ {"inputTensor2", { 0, 1, 5, 3 }}},
+ {{"outputTensor", { 1, 0, 1, 0 }}});
+}
+
+struct BroadcastLessOrEqualFixture : public ComparisonFixture
+{
+ BroadcastLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
+};
+
+TEST_CASE_FIXTURE(BroadcastLessOrEqualFixture, "BroadcastLessOrEqual")
+{
+ RunTest<2, armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>(
+ 0,
+ {{"inputTensor", { 5, 4, 1, 0 }},
+ {"inputTensor2", { 1, 3 }}},
+ {{"outputTensor", { 0, 0, 1, 1 }}});
+}
+
+}