From 8eb675eb77865b5d2491f5b2d650ce993cab738c Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 17 Oct 2018 14:43:29 +0100 Subject: IVGCVSW-2038 + IVGCVSW-2039 + IVGCVSW-2040 Add NHWC support to the Float32 and UInt8 BatchNormalization workloads * Enabled NHWC support in RefBatchNormalizationFloat32Workload * Added NHWC unit tests for both FP32 and U8 * Refactored the existing unit tests Change-Id: I6aa18f1dcc0666b80a17a7ed229cf53607bae147 --- src/backends/reference/workloads/BatchNormImpl.hpp | 36 +++++++++++++--------- 1 file changed, 21 insertions(+), 15 deletions(-) (limited to 'src/backends/reference/workloads/BatchNormImpl.hpp') diff --git a/src/backends/reference/workloads/BatchNormImpl.hpp b/src/backends/reference/workloads/BatchNormImpl.hpp index a7579c8373..fbcb2fdf5a 100644 --- a/src/backends/reference/workloads/BatchNormImpl.hpp +++ b/src/backends/reference/workloads/BatchNormImpl.hpp @@ -6,6 +6,7 @@ #pragma once #include "RefWorkloadUtils.hpp" +#include "TensorBufferArrayView.hpp" #include @@ -15,16 +16,27 @@ namespace armnn { template -static void BatchNormImpl(NormData data, +static void BatchNormImpl(NormData data, const float* varIn, const float* meanIn, const float* gammaIn, const float* betaIn, - float * outputData, - const float * inputData) + float* outputData, + const float* inputData) { - const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); - for (unsigned int c = 0; c < inputInfo0.GetShape()[1]; c++) + const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); + + TensorBufferArrayView input(inputInfo.GetShape(), + inputData, + data.m_Parameters.m_DataLayout); + TensorBufferArrayView output(outputInfo.GetShape(), + outputData, + data.m_Parameters.m_DataLayout); + + DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout); + + for (unsigned int c = 0; c < inputInfo.GetShape()[dataLayout.GetChannelsIndex()]; c++) { float var = varIn[c]; float mean = meanIn[c]; @@ -34,19 +46,13 @@ static void BatchNormImpl(NormData data, float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps); float add = beta - mult * mean; - for (unsigned int n = 0; n < inputInfo0.GetShape()[0]; n++) + for (unsigned int n = 0; n < inputInfo.GetShape()[0]; n++) { - for (unsigned int j = 0; j < inputInfo0.GetShape()[2]; j++) + for (unsigned int h = 0; h < inputInfo.GetShape()[dataLayout.GetHeightIndex()]; h++) { - for (unsigned int i = 0; i < inputInfo0.GetShape()[3]; i++) + for (unsigned int w = 0; w < inputInfo.GetShape()[dataLayout.GetWidthIndex()]; w++) { - unsigned int index = i + - j*inputInfo0.GetShape()[3] + - c*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2] + - n*inputInfo0.GetShape()[3] * inputInfo0.GetShape()[2] - * inputInfo0.GetShape()[1]; - - outputData[index] = mult * inputData[index] + add; + output.Get(n, c, h, w) = mult * input.Get(n, c, h, w) + add; } } } -- cgit v1.2.1