aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp56
1 files changed, 50 insertions, 6 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index f38f45fcdf..2df47eb198 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -623,6 +623,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_DIV] = &TfLiteParserImpl::ParseDiv;
m_ParserFunctions[tflite::BuiltinOperator_ELU] = &TfLiteParserImpl::ParseElu;
m_ParserFunctions[tflite::BuiltinOperator_EXP] = &TfLiteParserImpl::ParseExp;
+ m_ParserFunctions[tflite::BuiltinOperator_EXPAND_DIMS] = &TfLiteParserImpl::ParseExpandDims;
m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParserImpl::ParseFullyConnected;
m_ParserFunctions[tflite::BuiltinOperator_GATHER] = &TfLiteParserImpl::ParseGather;
m_ParserFunctions[tflite::BuiltinOperator_HARD_SWISH] = &TfLiteParserImpl::ParseHardSwish;
@@ -1091,6 +1092,37 @@ void TfLiteParserImpl::ParseDequantize(size_t subgraphIndex, size_t operatorInde
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
}
+void TfLiteParserImpl::ParseExpandDims(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = fmt::format("ExpandDims:{}:{}", subgraphIndex, operatorIndex);
+
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
+
+ CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
+
+ ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
+ ARMNN_ASSERT(layer != nullptr);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
void TfLiteParserImpl::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
{
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
@@ -1586,11 +1618,10 @@ void TfLiteParserImpl::ParseSpaceToBatchND(size_t subgraphIndex, size_t operator
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
-armnn::TensorInfo TfLiteParserImpl::OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDimsIn,
+armnn::TensorInfo TfLiteParserImpl::OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
const armnn::TensorInfo & inputTensorInfo)
{
- CHECK_VALID_SIZE(squeezeDimsIn.size(), 0, 1, 2, 3, 4);
- std::vector<uint32_t> squeezeDims = squeezeDimsIn;
+ CHECK_VALID_SIZE(squeezeDims.size(), 0, 1, 2, 3, 4);
static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 };
if (inputTensorInfo.GetNumDimensions() > 4)
@@ -1688,9 +1719,22 @@ void TfLiteParserImpl::ParseSqueeze(size_t subgraphIndex, size_t operatorIndex)
auto layerName = fmt::format("Squeeze:{}:{}", subgraphIndex, operatorIndex);
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo outputTensorInfo =
- TfLiteParserImpl::OutputShapeOfSqueeze(AsUnsignedVector(options->squeeze_dims),
- inputTensorInfo);
+
+ std::vector<uint32_t> squeezeDim;
+ // A single negative dim index is interpreted as a negative index in python
+ // Meaning the index will be the shape size plus the negative index value
+ if (options->squeeze_dims.size() == 1 && options->squeeze_dims[0] < 0)
+ {
+ int32_t dim = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions()) + options->squeeze_dims[0];
+ squeezeDim.push_back(static_cast<uint32_t>(dim));
+ }
+ else
+ {
+ squeezeDim = AsUnsignedVector(options->squeeze_dims);
+ }
+
+ armnn::TensorInfo outputTensorInfo = TfLiteParserImpl::OutputShapeOfSqueeze(squeezeDim, inputTensorInfo);
+
CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
ReshapeDescriptor reshapeDesc;