aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads/ClPooling2dWorkload.cpp
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2018-10-19 10:41:35 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-22 16:57:54 +0100
commit69482271d3e02af950d2d0f1947ae6c3eeed537b (patch)
treed0ef56a1ba2d314eb821ce2b6bb8e09773f41a17 /src/backends/cl/workloads/ClPooling2dWorkload.cpp
parentdd6aceaa884815e68ed69fca71de81babd3204da (diff)
downloadarmnn-69482271d3e02af950d2d0f1947ae6c3eeed537b.tar.gz
IVGCVSW-2024: Support NHWC for Pooling2D CpuRef
* Adds implementation to plumb DataLayout parameter for Pooling2D on CpuRef. * Adds unit tests to execute Pooling2D on CpuRef using NHWC data layout. * Refactors original tests to use DataLayoutIndexed and removes duplicate code. Change-Id: Ife7e0861a886cf58a2042e5be20e5b27af4528c9
Diffstat (limited to 'src/backends/cl/workloads/ClPooling2dWorkload.cpp')
-rw-r--r--src/backends/cl/workloads/ClPooling2dWorkload.cpp8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/backends/cl/workloads/ClPooling2dWorkload.cpp b/src/backends/cl/workloads/ClPooling2dWorkload.cpp
index 255f57341e..68512ff980 100644
--- a/src/backends/cl/workloads/ClPooling2dWorkload.cpp
+++ b/src/backends/cl/workloads/ClPooling2dWorkload.cpp
@@ -19,8 +19,10 @@ arm_compute::Status ClPooling2dWorkloadValidate(const TensorInfo& input,
const TensorInfo& output,
const Pooling2dDescriptor& descriptor)
{
- const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
- const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclInputInfo =
+ BuildArmComputeTensorInfo(input, descriptor.m_DataLayout.GetDataLayout());
+ const arm_compute::TensorInfo aclOutputInfo =
+ BuildArmComputeTensorInfo(output, descriptor.m_DataLayout.GetDataLayout());
arm_compute::PoolingLayerInfo layerInfo = BuildArmComputePoolingLayerInfo(descriptor);
@@ -36,7 +38,7 @@ ClPooling2dWorkload::ClPooling2dWorkload(
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
- arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
+ arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout.GetDataLayout());
input.info()->set_data_layout(aclDataLayout);
output.info()->set_data_layout(aclDataLayout);