From 8eb675eb77865b5d2491f5b2d650ce993cab738c Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 17 Oct 2018 14:43:29 +0100 Subject: IVGCVSW-2038 + IVGCVSW-2039 + IVGCVSW-2040 Add NHWC support to the Float32 and UInt8 BatchNormalization workloads * Enabled NHWC support in RefBatchNormalizationFloat32Workload * Added NHWC unit tests for both FP32 and U8 * Refactored the existing unit tests Change-Id: I6aa18f1dcc0666b80a17a7ed229cf53607bae147 --- src/backends/reference/test/RefLayerTests.cpp | 2 ++ src/backends/reference/workloads/BatchNormImpl.hpp | 36 +++++++++++++--------- .../RefBatchNormalizationFloat32Workload.hpp | 2 +- 3 files changed, 24 insertions(+), 16 deletions(-) (limited to 'src/backends/reference') diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 2815e342c0..6cfa4a3926 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -176,7 +176,9 @@ ARMNN_AUTO_TEST_CASE(MultiplicationBroadcast1DVectorUint8, MultiplicationBroadca // Batch Norm ARMNN_AUTO_TEST_CASE(BatchNorm, BatchNormTest) +ARMNN_AUTO_TEST_CASE(BatchNormNhwc, BatchNormNhwcTest) ARMNN_AUTO_TEST_CASE(BatchNormUint8, BatchNormUint8Test) +ARMNN_AUTO_TEST_CASE(BatchNormUint8Nhwc, BatchNormUint8NhwcTest) // Resize Bilinear - NCHW ARMNN_AUTO_TEST_CASE(SimpleResizeBilinear, SimpleResizeBilinearTest) diff --git a/src/backends/reference/workloads/BatchNormImpl.hpp b/src/backends/reference/workloads/BatchNormImpl.hpp index a7579c8373..fbcb2fdf5a 100644 --- a/src/backends/reference/workloads/BatchNormImpl.hpp +++ b/src/backends/reference/workloads/BatchNormImpl.hpp @@ -6,6 +6,7 @@ #pragma once #include "RefWorkloadUtils.hpp" +#include "TensorBufferArrayView.hpp" #include @@ -15,16 +16,27 @@ namespace armnn { template<typename NormData> -static void BatchNormImpl(NormData data, +static void BatchNormImpl(NormData data, const float* varIn, const float* meanIn, const float* gammaIn, const float* 
betaIn, - float * outputData, - const float * inputData) + float* outputData, + const float* inputData) { - const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); - for (unsigned int c = 0; c < inputInfo0.GetShape()[1]; c++) + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); + + TensorBufferArrayView<const float> input(inputInfo.GetShape(), + inputData, + data.m_Parameters.m_DataLayout); + TensorBufferArrayView<float> output(outputInfo.GetShape(), + outputData, + data.m_Parameters.m_DataLayout); + + DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); + + for (unsigned int c = 0; c < inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; c++) { float var = varIn[c]; float mean = meanIn[c]; @@ -34,19 +46,13 @@ static void BatchNormImpl(NormData data, float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); float add = beta - mult * mean; - for (unsigned int n = 0; n < inputInfo0.GetShape()[0]; n++) + for (unsigned int n = 0; n < inputInfo.GetShape()[0]; n++) { - for (unsigned int j = 0; j < inputInfo0.GetShape()[2]; j++) + for (unsigned int h = 0; h < inputInfo.GetShape()[dataLayout.GetHeightIndex()]; h++) { - for (unsigned int i = 0; i < inputInfo0.GetShape()[3]; i++) + for (unsigned int w = 0; w < inputInfo.GetShape()[dataLayout.GetWidthIndex()]; w++) { - unsigned int index = i + - j*inputInfo0.GetShape()[3] + - c*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2] + - n*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2] - * inputInfo0.GetShape()[1]; - - outputData[index] = mult * inputData[index] + add; + output.Get(n, c, h, w) = mult * input.Get(n, c, h, w) + add; } } } diff --git a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp index 17f80ca5e0..b51d94f979 100644 --- a/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp +++ 
b/src/backends/reference/workloads/RefBatchNormalizationFloat32Workload.hpp @@ -15,7 +15,7 @@ class RefBatchNormalizationFloat32Workload : public Float32Workload