aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp70
1 files changed, 53 insertions, 17 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index c4d2942779..cb3426ea5b 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -604,6 +604,8 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
{
// register supported operators
m_ParserFunctions[tflite::BuiltinOperator_ADD] = &TfLiteParserImpl::ParseAdd;
+ m_ParserFunctions[tflite::BuiltinOperator_ARG_MIN] = &TfLiteParserImpl::ParseArgMin;
+ m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX] = &TfLiteParserImpl::ParseArgMax;
m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParserImpl::ParseAveragePool2D;
m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND] = &TfLiteParserImpl::ParseBatchToSpaceND;
m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParserImpl::ParseConcatenation;
@@ -612,6 +614,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_DEPTH_TO_SPACE] = &TfLiteParserImpl::ParseDepthToSpace;
m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParserImpl::ParseDepthwiseConv2D;
m_ParserFunctions[tflite::BuiltinOperator_DEQUANTIZE] = &TfLiteParserImpl::ParseDequantize;
+ m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParserImpl::ParseDiv;
m_ParserFunctions[tflite::BuiltinOperator_ELU] = &TfLiteParserImpl::ParseElu;
m_ParserFunctions[tflite::BuiltinOperator_EXP] = &TfLiteParserImpl::ParseExp;
m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParserImpl::ParseFullyConnected;
@@ -649,8 +652,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE] = &TfLiteParserImpl::ParseTranspose;
m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV] = &TfLiteParserImpl::ParseTransposeConv;
m_ParserFunctions[tflite::BuiltinOperator_UNPACK] = &TfLiteParserImpl::ParseUnpack;
- m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParserImpl::ParseDiv;
- m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX] = &TfLiteParserImpl::ParseArgMax;
+
// register supported custom operators
m_CustomParserFunctions["TFLite_Detection_PostProcess"] = &TfLiteParserImpl::ParseDetectionPostProcess;
}
@@ -2939,8 +2941,18 @@ void TfLiteParserImpl::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
}
+void TfLiteParserImpl::ParseArgMin(size_t subgraphIndex, size_t operatorIndex)
+{
+ ParseArgMinMax(subgraphIndex, operatorIndex, armnn::ArgMinMaxFunction::Min);
+}
+
void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
{
+ ParseArgMinMax(subgraphIndex, operatorIndex, armnn::ArgMinMaxFunction::Max);
+}
+
+void TfLiteParserImpl::ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, ArgMinMaxFunction argMinMaxFunction)
+{
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(inputs.size(), 2);
@@ -2948,22 +2960,11 @@ void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
- auto layerName = fmt::format("ArgMax:{}:{}", subgraphIndex, operatorIndex);
-
- armnn::TensorInfo sizeTensorInfo0 = ToTensorInfo(inputs[0]);
- armnn::TensorInfo sizeTensorInfo1 = ToTensorInfo(inputs[1]);
-
- // Get const axis value from model and set it to descriptor.
- BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
-
- ArgMinMaxDescriptor desc;
- desc.m_Axis = axisBufferPtr->data.data()[0];
- desc.m_Function = ArgMinMaxFunction::Max;
-
- // Register a ArgMax layer.
- IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerName.c_str());
-
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo axisTensorInfo = ToTensorInfo(inputs[1]);
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+
+ // Check if output tensor type is Signed32 or Signed64
if (outputTensorInfo.GetDataType() != armnn::DataType::Signed32 &&
outputTensorInfo.GetDataType() != armnn::DataType::Signed64)
{
@@ -2972,6 +2973,41 @@ void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
"Output tensor data type is not supported. (Supported types: Signed32 & Signed64) {}",
CHECK_LOCATION().AsString()));
}
+
+ // Get const axis value from model and set it to descriptor.
+ BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+ if (axisBufferPtr == nullptr)
+ {
+ throw ParseException(
+ fmt::format("Operation has invalid inputs. Failed to read axis. {}",
+ CHECK_LOCATION().AsString()));
+ }
+
+ std::vector<int32_t> axisData(axisTensorInfo.GetNumElements());
+ ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
+ int32_t axis = axisData.front();
+
+ auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+ if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
+ {
+ // Square bracket denotes inclusive n while parenthesis denotes exclusive n
+ // E.g. Rank 4 tensor can have axis in range [-4, 3)
+ // -1 == 3, -2 == 2, -3 == 1, -4 == 0
+ throw ParseException(
+ fmt::format("Operation has invalid axis: {}. Axis must be in range [-n, n) {}",
+ axis,
+ CHECK_LOCATION().AsString()));
+ }
+
+ ArgMinMaxDescriptor desc;
+ desc.m_Axis = axis;
+ desc.m_Function = argMinMaxFunction;
+
+ // Register a ArgMin/ArgMax layer.
+ auto layerName = argMinMaxFunction == ArgMinMaxFunction::Max ? "ArgMax:{}:{}" : "ArgMin:{}:{}";
+ auto layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex);
+ IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerNameFormatted.c_str());
+ ARMNN_ASSERT(layer != nullptr);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
// Register input tensor to the layer.