Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp | 46
1 file changed, 45 insertions(+), 1 deletion(-)
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 6143f4af6a..0aad048970 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -345,7 +345,9 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr,
case tflite::TensorType_INT32:
type = armnn::DataType::Signed32;
break;
-
+ case tflite::TensorType_INT64:
+ type = armnn::DataType::Signed64;
+ break;
default:
{
CheckLocation location = CHECK_LOCATION();
@@ -598,6 +600,7 @@ TfLiteParser::TfLiteParser(const Optional<ITfLiteParser::TfLiteParserOptions>& o
m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV] = &TfLiteParser::ParseTransposeConv;
m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParser::ParseUnpack;
m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParser::ParseDiv;
+ m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX] = &TfLiteParser::ParseArgMax;
// register supported custom operators
m_CustomParserFunctions["TFLite_Detection_PostProcess"] = &TfLiteParser::ParseDetectionPostProcess;
}
@@ -2847,6 +2850,47 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
}
+void TfLiteParser::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
+{
+ const auto &operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto *options = operatorPtr->builtin_options.AsArgMaxOptions();
+
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = boost::str(boost::format("ArgMax:%1%:%2%") % subgraphIndex % operatorIndex);
+
+ armnn::TensorInfo sizeTensorInfo0 = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo sizeTensorInfo1 = ToTensorInfo(inputs[1]);
+
+ // Get the constant axis value from the model and set it on the descriptor.
+ BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+
+ ArgMinMaxDescriptor desc;
+ desc.m_Axis = axisBufferPtr->data.data()[0];
+ // If output_type is int32, set Signed32; otherwise default to Signed64.
+ desc.m_Output_Type = options->output_type == tflite::TensorType_INT32 ? armnn::DataType::Signed32 : armnn::DataType::Signed64;
+ desc.m_Function = ArgMinMaxFunction::Max;
+
+ // Register an ArgMax layer.
+ IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ // Register the first input tensor with the layer; the axis input is consumed as a constant above.
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+ // Register output tensor to the layer.
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
+}
+
armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)
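
A minimal usage sketch (not part of this change): with ARG_MAX registered in the parser, a .tflite model containing that operator can be loaded through the public armnnTfLiteParser API. The file name model_with_argmax.tflite is a placeholder for any such model.

#include <armnnTfLiteParser/ITfLiteParser.hpp>
#include <armnn/INetwork.hpp>

int main()
{
    // Create the TfLite parser with default options.
    armnnTfLiteParser::ITfLiteParserPtr parser = armnnTfLiteParser::ITfLiteParser::Create();

    // A model containing an ARG_MAX operator is now parsed into an
    // ArgMinMax layer instead of being rejected as an unsupported operator.
    armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model_with_argmax.tflite");

    return network ? 0 : 1;
}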