diff options
author | Mike Kelly <mike.kelly@arm.com> | 2021-04-06 12:25:55 +0100 |
---|---|---|
committer | Matthew Sloyan <matthew.sloyan@arm.com> | 2021-04-09 12:38:48 +0100 |
commit | 1f140f7226c4ed7bc5cbaf2ce09654eee452f4bf (patch) | |
tree | 044d491d429b6da4d85530afd4ea2a310cdbb827 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | 7c67fabc86b6647855beebac9f6cfe92341357cb (diff) | |
download | armnn-1f140f7226c4ed7bc5cbaf2ce09654eee452f4bf.tar.gz |
MLCE-328 Serializer/Deserializer does not support Signed64
* Added support for Signed64 to flatbuffer's schema & updated source tree
* Added support for Signed64 to TFLite Delegate
* Added support for Signed64 to Serializer
* Added support for Signed64 to Deserializer
* Added unit test for ArgMinMax to Deserializer
* Deprecated m_Output_Type from the ArgMinMaxDescriptor: the output type
is solely determined by the DataType of the output Tensor
* Fixed issue where RefArgMinMaxWorkload could output data using
the wrong DataType
* Added Signed64 to RefLayerSupport::IsArgMinMaxSupported as a supported
type
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: Ib622c052a1f8aa3e658262f8bde5a6881a8cbe10
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 8286007b04..c4d2942779 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -2941,9 +2941,6 @@ void TfLiteParserImpl::ParseSplitV(size_t subgraphIndex, size_t operatorIndex) void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex) { - const auto &operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; - const auto *options = operatorPtr->builtin_options.AsArgMaxOptions(); - CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(inputs.size(), 2); @@ -2961,14 +2958,20 @@ void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex) ArgMinMaxDescriptor desc; desc.m_Axis = axisBufferPtr->data.data()[0]; - // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64. - desc.m_Output_Type = options->output_type == 3 ? armnn::DataType::Signed32 : armnn::DataType::Signed64; desc.m_Function = ArgMinMaxFunction::Max; // Register a ArgMax layer. IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + if (outputTensorInfo.GetDataType() != armnn::DataType::Signed32 && + outputTensorInfo.GetDataType() != armnn::DataType::Signed64) + { + throw ParseException( + fmt::format( + "Output tensor data type is not supported. (Supported types: Signed32 & Signed64) {}", + CHECK_LOCATION().AsString())); + } layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); // Register input tensor to the layer. |