From ee9e7665a5922f7ec0c5ec24d6ab2ecd88fbcfd6 Mon Sep 17 00:00:00 2001 From: James Conroy Date: Mon, 1 Oct 2018 09:15:19 +0100 Subject: IVGCVSW-1931: Add data layout param for ResizeBilinear * Added data layout parameter to ResizeBilinear descriptor, in order to support NHWC. Change-Id: Ifdbc4529127b7329a056d0a68e2e42b175aeea4a --- include/armnn/Descriptors.hpp | 8 +++++--- src/backends/WorkloadData.cpp | 7 +++++-- src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp | 8 +++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 30c8144220..2de031e94a 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -92,7 +92,7 @@ struct ViewsDescriptor friend void swap(ViewsDescriptor& first, ViewsDescriptor& second); private: OriginsDescriptor m_Origins; - uint32_t** m_ViewSizes; + uint32_t** m_ViewSizes; }; /// Convenience template to create an OriginsDescriptor to use when creating a Merger layer for performing concatenation @@ -308,10 +308,12 @@ struct ResizeBilinearDescriptor ResizeBilinearDescriptor() : m_TargetWidth(0) , m_TargetHeight(0) + , m_DataLayout(DataLayout::NCHW) {} - uint32_t m_TargetWidth; - uint32_t m_TargetHeight; + uint32_t m_TargetWidth; + uint32_t m_TargetHeight; + DataLayout m_DataLayout; }; struct ReshapeDescriptor diff --git a/src/backends/WorkloadData.cpp b/src/backends/WorkloadData.cpp index c5c607d954..8b28b476b2 100644 --- a/src/backends/WorkloadData.cpp +++ b/src/backends/WorkloadData.cpp @@ -664,8 +664,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c } { - const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[1]; - const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[1]; + // DataLayout is NCHW by default (channelsIndex = 1) + const unsigned int channelsIndex = this->m_Parameters.m_DataLayout == armnn::DataLayout::NHWC ? 3 : 1; + + const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[channelsIndex]; + const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[channelsIndex]; if (inputChannelCount != outputChannelCount) { throw InvalidArgumentException( diff --git a/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp b/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp index 499466e959..1a330354e4 100644 --- a/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp @@ -8,14 +8,17 @@ #include #include #include +#include #include "ClWorkloadUtils.hpp" +using namespace armnn::armcomputetensorutils; + namespace armnn { ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinearQueueDescriptor& descriptor, - const WorkloadInfo& info) + const WorkloadInfo& info) : FloatWorkload(descriptor, info) { m_Data.ValidateInputsOutputs("ClResizeBilinearFloatWorkload", 1, 1); @@ -23,6 +26,9 @@ ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinea arm_compute::ICLTensor& input = static_cast(m_Data.m_Inputs[0])->GetTensor(); arm_compute::ICLTensor& output = static_cast(m_Data.m_Outputs[0])->GetTensor(); + (&input)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout)); + (&output)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout)); + m_ResizeBilinearLayer.configure(&input, &output, arm_compute::InterpolationPolicy::BILINEAR, arm_compute::BorderMode::REPLICATE, arm_compute::PixelValue(0.f), arm_compute::SamplingPolicy::TOP_LEFT); -- cgit v1.2.1