diff options
Diffstat (limited to 'src/armnn/layers/ResizeBilinearLayer.cpp')
-rw-r--r-- | src/armnn/layers/ResizeBilinearLayer.cpp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/src/armnn/layers/ResizeBilinearLayer.cpp b/src/armnn/layers/ResizeBilinearLayer.cpp index 9f0608d11c..fda93da99f 100644 --- a/src/armnn/layers/ResizeBilinearLayer.cpp +++ b/src/armnn/layers/ResizeBilinearLayer.cpp @@ -37,10 +37,14 @@ std::vector<TensorShape> ResizeBilinearLayer::InferOutputShapes(const std::vecto unsigned int outWidth = m_Param.m_TargetWidth; unsigned int outHeight = m_Param.m_TargetHeight; - unsigned int outChannels = inputShape[1]; + unsigned int outChannels = inputShape[m_Param.m_DataLayout.GetChannelsIndex()]; unsigned int outBatch = inputShape[0]; - return std::vector<TensorShape>({ TensorShape({outBatch, outChannels, outHeight, outWidth}) }); + TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ? + TensorShape( { outBatch, outHeight, outWidth, outChannels } ) : + TensorShape( { outBatch, outChannels, outHeight, outWidth }); + + return std::vector<TensorShape>({ tensorShape }); } void ResizeBilinearLayer::ValidateTensorShapesFromInputs() |