diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index cd0e9214c2..31e808fd6e 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -295,7 +295,7 @@ void CalcPadding(uint32_t inputSize, } } -armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr) +armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes) { armnn::DataType type; CHECK_TENSOR_PTR(tensorPtr); @@ -345,17 +345,21 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr) } } - auto const & dimensions = AsUnsignedVector(tensorPtr->shape); - // two statements (on purpose) for easier debugging: - armnn::TensorInfo result(static_cast<unsigned int>(tensorPtr->shape.size()), - dimensions.data(), + armnn::TensorInfo result(static_cast<unsigned int>(shapes.size()), + shapes.data(), type, quantizationScale, quantizationOffset); return result; } +armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr) +{ + auto const & dimensions = AsUnsignedVector(tensorPtr->shape); + return ToTensorInfo(tensorPtr, dimensions); +} + template<typename T> std::pair<armnn::ConstTensor, std::unique_ptr<T[]>> CreateConstTensorImpl(TfLiteParser::BufferRawPtr bufferPtr, @@ -1796,10 +1800,17 @@ void TfLiteParser::ParseDetectionPostProcess(size_t subgraphIndex, size_t operat BOOST_ASSERT(layer != nullptr); - // Register outputs + // The model does not specify the output shapes. + // The output shapes are calculated from the max_detection and max_classes_per_detection. + unsigned int numDetectedBox = desc.m_MaxDetections * desc.m_MaxClassesPerDetection; + m_OverridenOutputShapes.push_back({ 1, numDetectedBox, 4 }); + m_OverridenOutputShapes.push_back({ 1, numDetectedBox }); + m_OverridenOutputShapes.push_back({ 1, numDetectedBox }); + m_OverridenOutputShapes.push_back({ 1 }); + for (unsigned int i = 0 ; i < outputs.size() ; ++i) { - armnn::TensorInfo detectionBoxOutputTensorInfo = ToTensorInfo(outputs[i]); + armnn::TensorInfo detectionBoxOutputTensorInfo = ToTensorInfo(outputs[i], m_OverridenOutputShapes[i]); layer->GetOutputSlot(i).SetTensorInfo(detectionBoxOutputTensorInfo); } @@ -2232,12 +2243,15 @@ BindingPointInfo TfLiteParser::GetNetworkOutputBindingInfo(size_t subgraphId, { CHECK_SUBGRAPH(m_Model, subgraphId); auto outputs = GetSubgraphOutputs(m_Model, subgraphId); - for (auto const & output : outputs) + for (unsigned int i = 0; i < outputs.size(); ++i) { + auto const output = outputs[i]; if (output.second->name == name) { auto bindingId = GenerateLayerBindingId(subgraphId, output.first); - return std::make_pair(bindingId, ToTensorInfo(output.second)); + std::vector<unsigned int> shape = m_OverridenOutputShapes.size() > 0 ? + m_OverridenOutputShapes[i] : AsUnsignedVector(output.second->shape); + return std::make_pair(bindingId, ToTensorInfo(output.second, shape)); } } |