aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp130
1 files changed, 104 insertions, 26 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index fbbc5acb3c..95e0e0af6e 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -363,7 +363,7 @@ void CalcPadding(uint32_t inputSize,
}
armnn::TensorInfo ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,
- const std::vector<unsigned int>& shapes,
+ const std::vector<unsigned int>& shape,
const bool outputTensor = false)
{
armnn::DataType type;
@@ -412,14 +412,53 @@ armnn::TensorInfo ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,
location.AsString()));
}
}
- std::vector<unsigned int> safeShape = shapes;
- bool isDynamic = false;
- if (safeShape.size() == 0)
+ TensorShape tensorShape;
+
+ std::vector<unsigned int> safeShape = shape;
+ if (shape.size() == 0)
{
safeShape.push_back(1);
- if (outputTensor)
+ }
+
+ if (!outputTensor)
+ {
+ tensorShape = TensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()), safeShape.data());
+ }
+ else
+ {
+ unsigned long shapeSignatureSize = tensorPtr->shape_signature.size();
+
+ // If a shape signature exists we will use that to infer dynamic tensors
+ if (shapeSignatureSize != 0)
{
- isDynamic = true;
+ // If the shape is incompatible with the shape signature override the shape
+ if (shapeSignatureSize != shape.size())
+ {
+ safeShape = {};
+
+ for (unsigned int i = 0; i < shapeSignatureSize; ++i)
+ {
+ unsigned int dim = tensorPtr->shape_signature[i] > -1 ?
+ static_cast<unsigned int>(tensorPtr->shape_signature[i]) : 0;
+ safeShape.push_back(dim);
+ }
+ }
+
+ bool dimMask[tensorPtr->shape_signature.size()];
+ for (unsigned int i = 0; i < tensorPtr->shape_signature.size(); ++i)
+ {
+ dimMask[i] = tensorPtr->shape_signature[i] == -1 ? false : true;
+ }
+ tensorShape = TensorShape(static_cast<unsigned int>(safeShape.size()), safeShape.data(), dimMask);
+ }
+ // If there is no shape signature treat the tensor as dynamic if the shape has a size of zero
+ else if (shape.size() == 0)
+ {
+ tensorShape = TensorShape(1, false);
+ }
+ else
+ {
+ tensorShape = TensorShape(armnn::numeric_cast<unsigned int>(shape.size()), shape.data());
}
}
@@ -444,12 +483,6 @@ armnn::TensorInfo ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,
quantizationOffset = armnn::numeric_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
}
- TensorShape tensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()),
- safeShape.data());
- if (isDynamic)
- {
- tensorShape = TensorShape(1, false);
- }
armnn::TensorInfo result(tensorShape,
type,
quantizationScale,
@@ -467,12 +500,6 @@ armnn::TensorInfo ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,
std::back_inserter(quantizationScales));
// QSymmS8 Per-axis
- TensorShape tensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()),
- safeShape.data());
- if (isDynamic)
- {
- tensorShape = TensorShape(1, false);
- }
armnn::TensorInfo result(tensorShape,
type,
quantizationScales,
@@ -482,12 +509,6 @@ armnn::TensorInfo ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,
}
else
{
- TensorShape tensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()),
- safeShape.data());
- if (isDynamic)
- {
- tensorShape = TensorShape(1, false);
- }
armnn::TensorInfo result(tensorShape,
type,
quantizationScale,
@@ -695,6 +716,15 @@ INetworkPtr TfLiteParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>
return CreateNetworkFromModel();
}
+
+armnn::INetworkPtr TfLiteParserImpl::LoadModel(std::unique_ptr<tflite::ModelT> model)
+{
+ ResetParser();
+ m_Model = std::move(model);
+
+ return CreateNetworkFromModel();
+}
+
INetworkPtr TfLiteParserImpl::CreateNetworkFromModel()
{
@@ -1116,7 +1146,44 @@ void TfLiteParserImpl::ParseExpandDims(size_t subgraphIndex, size_t operatorInde
CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
ReshapeDescriptor reshapeDesc;
- reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+
+ if (outputTensorInfo.GetShape().AreAllDimensionsSpecified())
+ {
+ reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+ }
+ else
+ {
+ int32_t axis = inputs[1]->shape[0];
+
+ int32_t inputDimSize = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions());
+
+ if (axis > inputDimSize || axis < 0 - (inputDimSize + 1))
+ {
+ throw ParseException("axis must be in range [0 - (inputDimSize + 1), inputDimSize] inclusive");
+ }
+
+ if(axis < 0)
+ {
+ axis = inputDimSize + axis + 1;
+ }
+
+ unsigned int shape[inputDimSize + 1];
+ unsigned int inputShapeIndex = 0;
+ for (unsigned int i = 0; i < static_cast<unsigned int>(inputDimSize + 1); ++i)
+ {
+ if (i == static_cast<unsigned int>(axis))
+ {
+ shape[i] = 1;
+ }
+ else
+ {
+ shape[i] = inputTensorInfo.GetShape()[inputShapeIndex];
+ ++inputShapeIndex;
+ }
+ }
+
+ reshapeDesc.m_TargetShape = TensorShape(static_cast<unsigned int>(inputDimSize + 1), shape);
+ }
IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
ARMNN_ASSERT(layer != nullptr);
@@ -2790,13 +2857,24 @@ void TfLiteParserImpl::ParseUnpack(size_t subgraphIndex, size_t operatorIndex)
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+ std::vector<unsigned int> reshapeDims;
+ for (unsigned int axis = 0; axis < splitOutShape.GetNumDimensions(); ++axis)
+ {
+ if (axis != unpackAxis)
+ {
+ reshapeDims.push_back(splitOutShape[axis]);
+ }
+ }
+
+ TensorShape reshapeOutputShape(splitOutShape.GetNumDimensions() -1, reshapeDims.data());
+
// Create reshape to remove the unpacked dimension for unpack operator of each output from Splitter.
for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
{
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[k], true);
std::string reshapeLayerName = fmt::format("Reshape_for:{}", layer->GetName());
armnn::ReshapeDescriptor desc;
- desc.m_TargetShape = outputTensorInfo.GetShape();
+ desc.m_TargetShape = reshapeOutputShape;
armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(splitOutShape,