From ee9e7665a5922f7ec0c5ec24d6ab2ecd88fbcfd6 Mon Sep 17 00:00:00 2001 From: James Conroy Date: Mon, 1 Oct 2018 09:15:19 +0100 Subject: IVGCVSW-1931: Add data layout param for ResizeBilinear * Added data layout parameter to ResizeBilinear descriptor, in order to support NHWC. Change-Id: Ifdbc4529127b7329a056d0a68e2e42b175aeea4a --- src/backends/WorkloadData.cpp | 7 +++++-- src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp | 8 +++++++- 2 files changed, 12 insertions(+), 3 deletions(-) (limited to 'src/backends') diff --git a/src/backends/WorkloadData.cpp b/src/backends/WorkloadData.cpp index c5c607d954..8b28b476b2 100644 --- a/src/backends/WorkloadData.cpp +++ b/src/backends/WorkloadData.cpp @@ -664,8 +664,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c } { - const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[1]; - const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[1]; + // DataLayout is NCHW by default (channelsIndex = 1) + const unsigned int channelsIndex = this->m_Parameters.m_DataLayout == armnn::DataLayout::NHWC ? 3 : 1; + + const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[channelsIndex]; + const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[channelsIndex]; if (inputChannelCount != outputChannelCount) { throw InvalidArgumentException( diff --git a/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp b/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp index 499466e959..1a330354e4 100644 --- a/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp @@ -8,14 +8,17 @@ #include #include #include +#include #include "ClWorkloadUtils.hpp" +using namespace armnn::armcomputetensorutils; + namespace armnn { ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinearQueueDescriptor& descriptor, - const WorkloadInfo& info) + const WorkloadInfo& info) : FloatWorkload(descriptor, info) { m_Data.ValidateInputsOutputs("ClResizeBilinearFloatWorkload", 1, 1); @@ -23,6 +26,9 @@ ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinea arm_compute::ICLTensor& input = static_cast(m_Data.m_Inputs[0])->GetTensor(); arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[0])->GetTensor(); + (&input)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout)); + (&output)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout)); + m_ResizeBilinearLayer.configure(&input, &output, arm_compute::InterpolationPolicy::BILINEAR, arm_compute::BorderMode::REPLICATE, arm_compute::PixelValue(0.f), arm_compute::SamplingPolicy::TOP_LEFT); -- cgit v1.2.1