about summary refs log tree commit diff
path: root/src/backends/neon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon')
-rw-r--r--src/backends/neon/test/NeonCreateWorkloadTests.cpp31
-rw-r--r--src/backends/neon/test/NeonLayerTests.cpp1
-rw-r--r--src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp24
3 files changed, 42 insertions, 14 deletions
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index a588a3ecc8..8d5574c6a7 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -153,30 +153,45 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
}
template <typename BatchNormalizationWorkloadType, typename armnn::DataType DataType>
-static void NeonCreateBatchNormalizationWorkloadTest()
+static void NeonCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
{
Graph graph;
NeonWorkloadFactory factory;
- auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, graph);
+ auto workload = CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>
+ (factory, graph, dataLayout);
// Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
BatchNormalizationQueueDescriptor queueDescriptor = workload->GetData();
auto inputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Inputs[0]);
auto outputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({2, 3, 1, 1}, DataType)));
- BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({2, 3, 1, 1}, DataType)));
+
+ TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 4, 4} : TensorShape{2, 4, 4, 3};
+ TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 4, 4} : TensorShape{2, 4, 4, 3};
+
+ BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo(inputShape, DataType)));
+ BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo(outputShape, DataType)));
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16Workload)
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16NchwWorkload)
+{
+ NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float16>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat16NhwcWorkload)
{
- NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float16>();
+ NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float16>(DataLayout::NHWC);
}
#endif
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatWorkload)
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatNchwWorkload)
+{
+ NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatNhwcWorkload)
{
- NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float32>();
+ NeonCreateBatchNormalizationWorkloadTest<NeonBatchNormalizationFloatWorkload, DataType::Float32>(DataLayout::NHWC);
}
template <typename armnn::DataType DataType>
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 31ee7d87c1..568a2367c6 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -338,6 +338,7 @@ ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVector, MultiplicationBroadcast1DV
// Batch Norm
ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest)
+ARMNN_AUTO_TEST_CASE(BatchNormNhwc, BatchNormNhwcTest)
// Constant
ARMNN_AUTO_TEST_CASE(Constant, ConstantTest)
diff --git a/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp b/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp
index f7056a515b..95cfdce9b4 100644
--- a/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonBatchNormalizationFloatWorkload.cpp
@@ -21,12 +21,20 @@ arm_compute::Status NeonBatchNormalizationValidate(const TensorInfo& input,
const TensorInfo& gamma,
const BatchNormalizationDescriptor& descriptor)
{
- const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
- const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
- const arm_compute::TensorInfo aclMeanInfo = BuildArmComputeTensorInfo(mean);
- const arm_compute::TensorInfo aclVarInfo = BuildArmComputeTensorInfo(var);
- const arm_compute::TensorInfo aclBetaInfo = BuildArmComputeTensorInfo(beta);
- const arm_compute::TensorInfo aclGammaInfo = BuildArmComputeTensorInfo(gamma);
+ const DataLayout dataLayout = descriptor.m_DataLayout.GetDataLayout();
+
+ const arm_compute::TensorInfo aclInputInfo =
+ armcomputetensorutils::BuildArmComputeTensorInfo(input, dataLayout);
+ const arm_compute::TensorInfo aclOutputInfo =
+ armcomputetensorutils::BuildArmComputeTensorInfo(output, dataLayout);
+ const arm_compute::TensorInfo aclMeanInfo =
+ armcomputetensorutils::BuildArmComputeTensorInfo(mean, dataLayout);
+ const arm_compute::TensorInfo aclVarInfo =
+ armcomputetensorutils::BuildArmComputeTensorInfo(var, dataLayout);
+ const arm_compute::TensorInfo aclBetaInfo =
+ armcomputetensorutils::BuildArmComputeTensorInfo(beta, dataLayout);
+ const arm_compute::TensorInfo aclGammaInfo =
+ armcomputetensorutils::BuildArmComputeTensorInfo(gamma, dataLayout);
return arm_compute::NEBatchNormalizationLayer::validate(&aclInputInfo,
&aclOutputInfo,
@@ -46,6 +54,10 @@ NeonBatchNormalizationFloatWorkload::NeonBatchNormalizationFloatWorkload(
arm_compute::ITensor& input = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout.GetDataLayout());
+ input.info()->set_data_layout(aclDataLayout);
+ output.info()->set_data_layout(aclDataLayout);
+
m_Mean = std::make_unique<arm_compute::Tensor>();
BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());