aboutsummaryrefslogtreecommitdiff
path: root/src/backends
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2018-10-01 09:15:19 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:58 +0100
commitee9e7665a5922f7ec0c5ec24d6ab2ecd88fbcfd6 (patch)
treebc1e65f452ac3997d30cc647da0d3531910a310f /src/backends
parent616e775763280992de92287b129dc335be91a24c (diff)
downloadarmnn-ee9e7665a5922f7ec0c5ec24d6ab2ecd88fbcfd6.tar.gz
IVGCVSW-1931: Add data layout param for ResizeBilinear
* Added data layout parameter to ResizeBilinear descriptor, in order to support NHWC. Change-Id: Ifdbc4529127b7329a056d0a68e2e42b175aeea4a
Diffstat (limited to 'src/backends')
-rw-r--r--src/backends/WorkloadData.cpp7
-rw-r--r--src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp8
2 files changed, 12 insertions, 3 deletions
diff --git a/src/backends/WorkloadData.cpp b/src/backends/WorkloadData.cpp
index c5c607d954..8b28b476b2 100644
--- a/src/backends/WorkloadData.cpp
+++ b/src/backends/WorkloadData.cpp
@@ -664,8 +664,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
}
{
- const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
- const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[1];
+ // DataLayout is NCHW by default (channelsIndex = 1)
+ const unsigned int channelsIndex = this->m_Parameters.m_DataLayout == armnn::DataLayout::NHWC ? 3 : 1;
+
+ const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[channelsIndex];
+ const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[channelsIndex];
if (inputChannelCount != outputChannelCount)
{
throw InvalidArgumentException(
diff --git a/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp b/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp
index 499466e959..1a330354e4 100644
--- a/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp
@@ -8,14 +8,17 @@
#include <backends/CpuTensorHandle.hpp>
#include <backends/cl/ClLayerSupport.hpp>
#include <backends/aclCommon/ArmComputeUtils.hpp>
+#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
#include "ClWorkloadUtils.hpp"
+using namespace armnn::armcomputetensorutils;
+
namespace armnn
{
ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinearQueueDescriptor& descriptor,
- const WorkloadInfo& info)
+ const WorkloadInfo& info)
: FloatWorkload<ResizeBilinearQueueDescriptor>(descriptor, info)
{
m_Data.ValidateInputsOutputs("ClResizeBilinearFloatWorkload", 1, 1);
@@ -23,6 +26,9 @@ ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinea
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ (&input)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout));
+ (&output)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout));
+
m_ResizeBilinearLayer.configure(&input, &output, arm_compute::InterpolationPolicy::BILINEAR,
arm_compute::BorderMode::REPLICATE, arm_compute::PixelValue(0.f),
arm_compute::SamplingPolicy::TOP_LEFT);