aboutsummaryrefslogtreecommitdiff
path: root/src/backends/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/WorkloadData.cpp')
-rw-r--r--src/backends/WorkloadData.cpp7
1 files changed, 5 insertions, 2 deletions
diff --git a/src/backends/WorkloadData.cpp b/src/backends/WorkloadData.cpp
index c5c607d954..8b28b476b2 100644
--- a/src/backends/WorkloadData.cpp
+++ b/src/backends/WorkloadData.cpp
@@ -664,8 +664,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
}
{
- const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
- const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[1];
+ // DataLayout is NCHW by default (channelsIndex = 1)
+ const unsigned int channelsIndex = this->m_Parameters.m_DataLayout == armnn::DataLayout::NHWC ? 3 : 1;
+
+ const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[channelsIndex];
+ const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[channelsIndex];
if (inputChannelCount != outputChannelCount)
{
throw InvalidArgumentException(