aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ResizeBilinearLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ResizeBilinearLayer.cpp')
-rw-r--r--src/armnn/layers/ResizeBilinearLayer.cpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/armnn/layers/ResizeBilinearLayer.cpp b/src/armnn/layers/ResizeBilinearLayer.cpp
index 9f0608d11c..fda93da99f 100644
--- a/src/armnn/layers/ResizeBilinearLayer.cpp
+++ b/src/armnn/layers/ResizeBilinearLayer.cpp
@@ -37,10 +37,14 @@ std::vector<TensorShape> ResizeBilinearLayer::InferOutputShapes(const std::vecto
unsigned int outWidth = m_Param.m_TargetWidth;
unsigned int outHeight = m_Param.m_TargetHeight;
- unsigned int outChannels = inputShape[1];
+ unsigned int outChannels = inputShape[m_Param.m_DataLayout.GetChannelsIndex()];
unsigned int outBatch = inputShape[0];
- return std::vector<TensorShape>({ TensorShape({outBatch, outChannels, outHeight, outWidth}) });
+ TensorShape tensorShape = m_Param.m_DataLayout == armnn::DataLayout::NHWC ?
+ TensorShape( { outBatch, outHeight, outWidth, outChannels } ) :
+ TensorShape( { outBatch, outChannels, outHeight, outWidth });
+
+ return std::vector<TensorShape>({ tensorShape });
}
void ResizeBilinearLayer::ValidateTensorShapesFromInputs()