diff options
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 8286007b04..c4d2942779 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -2941,9 +2941,6 @@ void TfLiteParserImpl::ParseSplitV(size_t subgraphIndex, size_t operatorIndex) void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex) { - const auto &operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; - const auto *options = operatorPtr->builtin_options.AsArgMaxOptions(); - CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(inputs.size(), 2); @@ -2961,14 +2958,20 @@ void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex) ArgMinMaxDescriptor desc; desc.m_Axis = axisBufferPtr->data.data()[0]; - // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64. - desc.m_Output_Type = options->output_type == 3 ? armnn::DataType::Signed32 : armnn::DataType::Signed64; desc.m_Function = ArgMinMaxFunction::Max; // Register a ArgMax layer. IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + if (outputTensorInfo.GetDataType() != armnn::DataType::Signed32 && + outputTensorInfo.GetDataType() != armnn::DataType::Signed64) + { + throw ParseException( + fmt::format( + "Output tensor data type is not supported. (Supported types: Signed32 & Signed64) {}", + CHECK_LOCATION().AsString())); + } layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); // Register input tensor to the layer. |