aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp13
1 files changed, 8 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 8286007b04..c4d2942779 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -2941,9 +2941,6 @@ void TfLiteParserImpl::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
{
- const auto &operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
- const auto *options = operatorPtr->builtin_options.AsArgMaxOptions();
-
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(inputs.size(), 2);
@@ -2961,14 +2958,20 @@ void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
ArgMinMaxDescriptor desc;
desc.m_Axis = axisBufferPtr->data.data()[0];
- // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64.
- desc.m_Output_Type = options->output_type == 3 ? armnn::DataType::Signed32 : armnn::DataType::Signed64;
desc.m_Function = ArgMinMaxFunction::Max;
// Register a ArgMax layer.
IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str());
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ if (outputTensorInfo.GetDataType() != armnn::DataType::Signed32 &&
+ outputTensorInfo.GetDataType() != armnn::DataType::Signed64)
+ {
+ throw ParseException(
+ fmt::format(
+ "Output tensor data type is not supported. (Supported types: Signed32 & Signed64) {}",
+ CHECK_LOCATION().AsString()));
+ }
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
// Register input tensor to the layer.