diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 6143f4af6a..0aad048970 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -345,7 +345,9 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, case tflite::TensorType_INT32: type = armnn::DataType::Signed32; break; - + case tflite::TensorType_INT64: + type = armnn::DataType::Signed64; + break; default: { CheckLocation location = CHECK_LOCATION(); @@ -598,6 +600,7 @@ TfLiteParser::TfLiteParser(const Optional<ITfLiteParser::TfLiteParserOptions>& o m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV] = &TfLiteParser::ParseTransposeConv; m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParser::ParseUnpack; m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParser::ParseDiv; + m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX] = &TfLiteParser::ParseArgMax; // register supported custom operators m_CustomParserFunctions["TFLite_Detection_PostProcess"] = &TfLiteParser::ParseDetectionPostProcess; } @@ -2847,6 +2850,47 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); } +void TfLiteParser::ParseArgMax(size_t subgraphIndex, size_t operatorIndex) +{ + const auto &operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; + const auto *options = operatorPtr->builtin_options.AsArgMaxOptions(); + + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = boost::str(boost::format("ArgMax:%1%:%2%") % subgraphIndex % operatorIndex); + + armnn::TensorInfo sizeTensorInfo0 = ToTensorInfo(inputs[0]); + armnn::TensorInfo sizeTensorInfo1 = ToTensorInfo(inputs[1]); + + // Get const axis value from model and set it to descriptor. + BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + + ArgMinMaxDescriptor desc; + desc.m_Axis = axisBufferPtr->data.data()[0]; + // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64. + desc.m_Output_Type = options->output_type == 3 ? armnn::DataType::Signed32 : armnn::DataType::Signed64; + desc.m_Function = ArgMinMaxFunction::Max; + + // Register a ArgMax layer. + IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // Register input tensor to the layer. + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + // Register output tensor to the layer. + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes); +} + armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer, unsigned int outputSlot, tflite::ActivationFunctionType activationType) |