author      Teresa Charlin <teresa.charlinreyes@arm.com>    2023-06-14 14:51:17 +0100
committer   TeresaARM <teresa.charlinreyes@arm.com>         2023-06-22 12:12:31 +0000
commit      a7a605a7554158a47127b1988b94c6bfac0a64b9 (patch)
tree        9dce948707b266e45b0a5a6974ca940b54ce7f42 /src/armnnTfLiteParser/TfLiteParser.cpp
parent      a9c5c16154386fb1ec924c3778caf6b027bb57ee (diff)
IVGCVSW-7785 BugFix: ExpandDims not working when batch!=1
* This commit partially fixes the ticket. In ToTensorInfo() we assume the batch is 1 when it is unknown, and we later call OutputTensorInfoFromInputs() to amend that assumption. However, this does not work for the Reshape layer, so the output shape has to be calculated directly in ParseExpandDims().

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Iedc32a44b4ec0d8b7d2cc0b08f38f0776402f7bd
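For context, the shape rule that the new parser code implements can be summarised in a small standalone sketch. This is not ArmNN code: the function name ExpandDimsShape, the use of std::vector<unsigned int> in place of armnn::TensorShape, and the std::invalid_argument error are illustrative assumptions. It only mirrors the logic added in the diff below: validate the axis against [-outputRank, outputRank), normalise a negative axis, then insert a dimension of size 1.

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Compute the output shape of an ExpandDims operation by inserting a
// dimension of size 1 at 'axis'. A negative axis counts back from the end
// of the *output* shape, so the valid range is [-outputRank, outputRank).
std::vector<unsigned int> ExpandDimsShape(const std::vector<unsigned int>& inputShape, int32_t axis)
{
    auto inputRank  = static_cast<int32_t>(inputShape.size());
    auto outputRank = inputRank + 1;

    if (axis < -outputRank || axis >= outputRank)
    {
        // Illustrative error handling; the parser throws armnn::ParseException instead.
        throw std::invalid_argument("axis is not within [-outputRank, outputRank)");
    }

    // Normalise a negative axis to its positive equivalent.
    axis = axis < 0 ? axis + outputRank : axis;

    std::vector<unsigned int> outputShape(static_cast<std::size_t>(outputRank));
    std::size_t inputIndex = 0;
    for (std::size_t i = 0; i < outputShape.size(); ++i)
    {
        if (i == static_cast<std::size_t>(axis))
        {
            outputShape[i] = 1;                         // the inserted dimension
        }
        else
        {
            outputShape[i] = inputShape[inputIndex++];  // copy the existing dimensions
        }
    }
    return outputShape;
}

For example, ExpandDimsShape({2, 3, 4}, 0) gives {1, 2, 3, 4} and ExpandDimsShape({2, 3, 4}, -1) gives {2, 3, 4, 1}; the leading dimension no longer has to be 1, which is the batch != 1 case from the ticket.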
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--    src/armnnTfLiteParser/TfLiteParser.cpp    66
1 file changed, 33 insertions(+), 33 deletions(-)
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 8b2d7a25bc..6354a1e13a 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1509,57 +1509,57 @@ void TfLiteParserImpl::ParseExpandDims(size_t subgraphIndex, size_t operatorInde
     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
-
     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
-    ReshapeDescriptor reshapeDesc;
+    armnn::TensorInfo axisTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
-    if (outputTensorInfo.GetShape().AreAllDimensionsSpecified())
+    BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+    if (axisBufferPtr == nullptr)
     {
-        reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+        throw ParseException(fmt::format("{}: Operation has invalid inputs. Failed to read axis.",
+                                         CHECK_LOCATION().AsString()));
     }
-    else
-    {
-        int32_t axis = inputs[1]->shape[0];
-        int32_t inputDimSize = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions());
+    std::vector<int32_t> axisData(axisTensorInfo.GetNumElements());
+    ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
+    int32_t axis = axisData[0];
-        if (axis > inputDimSize || axis < 0 - (inputDimSize + 1))
-        {
-            throw ParseException("axis must be in range [0 - (inputDimSize + 1), inputDimSize] inclusive");
-        }
+    auto inputRank = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions());
+    auto outputRank = inputRank + 1;
+    if((axis < -1 * outputRank) || (outputRank <= axis))
+    {
+        throw ParseException(fmt::format("{}: Axis {} is not within [-{}, {}) range.",
+                                         CHECK_LOCATION().AsString(), axis, outputRank, outputRank));
+    }
+
+    axis = axis < 0 ? (axis + outputRank) : axis;
-        if(axis < 0)
+    std::vector<unsigned int> shape(static_cast<unsigned int>(outputRank));
+    unsigned int inputShapeIndex = 0;
+    for (unsigned int i = 0; i < static_cast<unsigned int>(outputRank); ++i)
+    {
+        if (i == static_cast<unsigned int>(axis))
         {
-            axis = inputDimSize + axis + 1;
+            shape[i] = 1;
         }
-
-        std::vector<unsigned int> shape(static_cast<unsigned int>(inputDimSize) + 1);
-        unsigned int inputShapeIndex = 0;
-        for (unsigned int i = 0; i < static_cast<unsigned int>(inputDimSize + 1); ++i)
+        else
         {
-            if (i == static_cast<unsigned int>(axis))
-            {
-                shape[i] = 1;
-            }
-            else
-            {
-                shape[i] = inputTensorInfo.GetShape()[inputShapeIndex];
-                ++inputShapeIndex;
-            }
+            shape[i] = inputTensorInfo.GetShape()[inputShapeIndex];
+            ++inputShapeIndex;
         }
-
-        reshapeDesc.m_TargetShape = TensorShape(static_cast<unsigned int>(inputDimSize + 1), shape.data());
     }
-    IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
-    ARMNN_ASSERT(layer != nullptr);
-
-    reshapeDesc.m_TargetShape = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0}).GetShape();
+    ReshapeDescriptor reshapeDesc;
+    reshapeDesc.m_TargetShape = TensorShape(static_cast<unsigned int>(outputRank), shape.data());
     outputTensorInfo.SetShape(reshapeDesc.m_TargetShape);
+    IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
+    ARMNN_ASSERT(layer != nullptr);
     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+    auto outputTensorIds = GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex);
+    m_TensorInfos[outputTensorIds[0]] = outputTensorInfo;
+
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});