aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp32
1 files changed, 23 insertions, 9 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index cd0e9214c2..31e808fd6e 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -295,7 +295,7 @@ void CalcPadding(uint32_t inputSize,
}
}
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes)
{
armnn::DataType type;
CHECK_TENSOR_PTR(tensorPtr);
@@ -345,17 +345,21 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
}
}
- auto const & dimensions = AsUnsignedVector(tensorPtr->shape);
-
// two statements (on purpose) for easier debugging:
- armnn::TensorInfo result(static_cast<unsigned int>(tensorPtr->shape.size()),
- dimensions.data(),
+ armnn::TensorInfo result(static_cast<unsigned int>(shapes.size()),
+ shapes.data(),
type,
quantizationScale,
quantizationOffset);
return result;
}
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
+{
+ auto const & dimensions = AsUnsignedVector(tensorPtr->shape);
+ return ToTensorInfo(tensorPtr, dimensions);
+}
+
template<typename T>
std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
CreateConstTensorImpl(TfLiteParser::BufferRawPtr bufferPtr,
@@ -1796,10 +1800,17 @@ void TfLiteParser::ParseDetectionPostProcess(size_t subgraphIndex, size_t operat
BOOST_ASSERT(layer != nullptr);
- // Register outputs
+ // The model does not specify the output shapes.
+ // The output shapes are calculated from the max_detection and max_classes_per_detection.
+ unsigned int numDetectedBox = desc.m_MaxDetections * desc.m_MaxClassesPerDetection;
+ m_OverridenOutputShapes.push_back({ 1, numDetectedBox, 4 });
+ m_OverridenOutputShapes.push_back({ 1, numDetectedBox });
+ m_OverridenOutputShapes.push_back({ 1, numDetectedBox });
+ m_OverridenOutputShapes.push_back({ 1 });
+
for (unsigned int i = 0 ; i < outputs.size() ; ++i)
{
- armnn::TensorInfo detectionBoxOutputTensorInfo = ToTensorInfo(outputs[i]);
+ armnn::TensorInfo detectionBoxOutputTensorInfo = ToTensorInfo(outputs[i], m_OverridenOutputShapes[i]);
layer->GetOutputSlot(i).SetTensorInfo(detectionBoxOutputTensorInfo);
}
@@ -2232,12 +2243,15 @@ BindingPointInfo TfLiteParser::GetNetworkOutputBindingInfo(size_t subgraphId,
{
CHECK_SUBGRAPH(m_Model, subgraphId);
auto outputs = GetSubgraphOutputs(m_Model, subgraphId);
- for (auto const & output : outputs)
+ for (unsigned int i = 0; i < outputs.size(); ++i)
{
+ auto const output = outputs[i];
if (output.second->name == name)
{
auto bindingId = GenerateLayerBindingId(subgraphId, output.first);
- return std::make_pair(bindingId, ToTensorInfo(output.second));
+ std::vector<unsigned int> shape = m_OverridenOutputShapes.size() > 0 ?
+ m_OverridenOutputShapes[i] : AsUnsignedVector(output.second->shape);
+ return std::make_pair(bindingId, ToTensorInfo(output.second, shape));
}
}