diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 2df47eb198..410f452ff1 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -622,12 +622,17 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_DEQUANTIZE] = &TfLiteParserImpl::ParseDequantize; m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParserImpl::ParseDiv; m_ParserFunctions[tflite::BuiltinOperator_ELU] = &TfLiteParserImpl::ParseElu; + m_ParserFunctions[tflite::BuiltinOperator_EQUAL] = &TfLiteParserImpl::ParseEqual; m_ParserFunctions[tflite::BuiltinOperator_EXP] = &TfLiteParserImpl::ParseExp; m_ParserFunctions[tflite::BuiltinOperator_EXPAND_DIMS] = &TfLiteParserImpl::ParseExpandDims; m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParserImpl::ParseFullyConnected; m_ParserFunctions[tflite::BuiltinOperator_GATHER] = &TfLiteParserImpl::ParseGather; + m_ParserFunctions[tflite::BuiltinOperator_GREATER] = &TfLiteParserImpl::ParseGreater; + m_ParserFunctions[tflite::BuiltinOperator_GREATER_EQUAL] = &TfLiteParserImpl::ParseGreaterOrEqual; m_ParserFunctions[tflite::BuiltinOperator_HARD_SWISH] = &TfLiteParserImpl::ParseHardSwish; m_ParserFunctions[tflite::BuiltinOperator_LEAKY_RELU] = &TfLiteParserImpl::ParseLeakyRelu; + m_ParserFunctions[tflite::BuiltinOperator_LESS] = &TfLiteParserImpl::ParseLess; + m_ParserFunctions[tflite::BuiltinOperator_LESS_EQUAL] = &TfLiteParserImpl::ParseLessOrEqual; m_ParserFunctions[tflite::BuiltinOperator_LOGICAL_NOT] = &TfLiteParserImpl::ParseLogicalNot; m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParserImpl::ParseLogistic; m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION] = &TfLiteParserImpl::ParseL2Normalization; @@ -637,6 +642,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_MINIMUM] = &TfLiteParserImpl::ParseMinimum; m_ParserFunctions[tflite::BuiltinOperator_MUL] = &TfLiteParserImpl::ParseMul; m_ParserFunctions[tflite::BuiltinOperator_NEG] = &TfLiteParserImpl::ParseNeg; + m_ParserFunctions[tflite::BuiltinOperator_NOT_EQUAL] = &TfLiteParserImpl::ParseNotEqual; m_ParserFunctions[tflite::BuiltinOperator_PACK] = &TfLiteParserImpl::ParsePack; m_ParserFunctions[tflite::BuiltinOperator_PAD] = &TfLiteParserImpl::ParsePad; m_ParserFunctions[tflite::BuiltinOperator_PRELU] = &TfLiteParserImpl::ParsePrelu; @@ -3373,6 +3379,69 @@ void TfLiteParserImpl::ParseElementwiseUnary(size_t subgraphIndex, size_t operat RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); } +void TfLiteParserImpl::ParseEqual(size_t subgraphIndex, size_t operatorIndex) +{ + ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Equal); +} + +void TfLiteParserImpl::ParseNotEqual(size_t subgraphIndex, size_t operatorIndex) +{ + ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::NotEqual); +} + +void TfLiteParserImpl::ParseGreater(size_t subgraphIndex, size_t operatorIndex) +{ + ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Greater); +} + +void TfLiteParserImpl::ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex) +{ + ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::GreaterOrEqual); +} + +void TfLiteParserImpl::ParseLess(size_t subgraphIndex, size_t operatorIndex) +{ + ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Less); +} + +void TfLiteParserImpl::ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex) +{ + ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::LessOrEqual); +} + +void TfLiteParserImpl::ParseComparison(size_t subgraphIndex, size_t operatorIndex, + ComparisonOperation comparisonOperation) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = std::string(GetComparisonOperationAsCString(comparisonOperation)) + ":{}:{}"; + std::string layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex); + + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo input1TensorInfo = ToTensorInfo(inputs[1]); + CheckMatchingQuantization(inputTensorInfo, input1TensorInfo, layerNameFormatted, "Input 0", "Input 1"); + + ComparisonDescriptor desc; + desc.m_Operation = comparisonOperation; + IConnectableLayer* layer = m_Network->AddComparisonLayer(desc, layerNameFormatted.c_str()); + ARMNN_ASSERT(layer != nullptr); + + TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]}); + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + armnn::IConnectableLayer* TfLiteParserImpl::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer, unsigned int outputSlot, tflite::ActivationFunctionType activationType) |