From 9b3983299f882c8d84c5abd0d40ca75a801ea7f2 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Wed, 22 May 2019 17:21:49 +0100 Subject: IVGCVSW-3025: Refactor reference Convolution2d workload * Refactored RefConvolution2dWorkload to support all DataTypes through Encoders and Decoders. * Added Convolute function to ConvImpl that uses Encoders and Decoders to support all DataTypes. * Deleted RefConvolution2dFloat32Workload and RefConvolution2dUint8Workload. Signed-off-by: Mike Kelly Signed-off-by: Teresa Charlin Change-Id: Ic5ef0f499d08b948fa65fdee54b5f681fd0b1c05 --- src/backends/reference/workloads/ConvImpl.cpp | 173 ++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) (limited to 'src/backends/reference/workloads/ConvImpl.cpp') diff --git a/src/backends/reference/workloads/ConvImpl.cpp b/src/backends/reference/workloads/ConvImpl.cpp index 8743a2bd0d..6a5ac535e4 100644 --- a/src/backends/reference/workloads/ConvImpl.cpp +++ b/src/backends/reference/workloads/ConvImpl.cpp @@ -68,4 +68,177 @@ int32_t QuantizedMultiplierSmallerThanOne::RoundingDivideByPOT(int32_t x, int ex return (x >> exponent) + (remainder > threshold ? 1 : 0); } +inline unsigned int GetOffset(DataLayout& dataLayout, const TensorShape& shape, unsigned int b, unsigned int c, + unsigned int h, unsigned int w) +{ + switch (dataLayout) + { + case DataLayout::NHWC: + b *= shape[1] * shape[2] * shape[3]; + h *= shape[2] * shape[3]; + w *= shape[3]; + break; + case DataLayout::NCHW: + default: + b *= shape[1] * shape[2] * shape[3]; + c *= shape[2] * shape[3]; + h *= shape[3]; + break; + } + return b + c + h + w; +} + +void Convolve(const TensorShape& rInputShape, + Decoder& rInputDecoder, + const TensorShape& rOutputShape, + Encoder& rOutputEncoder, + const TensorShape& rFilterShape, + Decoder& rFilterDecoder, + bool biasEnabled, + Decoder* pBiasDecoder, + DataLayout dataLayout, + unsigned int paddingTop, + unsigned int paddingLeft, + unsigned int xStride, + unsigned int yStride, + unsigned int xDilation, + unsigned int yDilation, + bool depthwise) +{ + if (biasEnabled && !pBiasDecoder) + { + throw InvalidArgumentException("Bias is enabled but the bias data is invalid"); + } + const armnnUtils::DataLayoutIndexed dataLayoutIndexed(dataLayout); + + const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex(); + const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex(); + const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex(); + + unsigned int depthMultiplier = depthwise ? rFilterShape[0] : 1; + unsigned int inputChannels = depthwise ? rFilterShape[1] : rFilterShape[channelsIndex]; + unsigned int outputChannels = depthwise ? inputChannels * depthMultiplier : rFilterShape[0]; + + unsigned int batchSize = rOutputShape[0]; + unsigned int outputHeight = rOutputShape[heightIndex]; + unsigned int outputWidth = rOutputShape[widthIndex]; + unsigned int inputHeight = rInputShape[heightIndex]; + unsigned int inputWidth = rInputShape[widthIndex]; + + unsigned int filterHeight = depthwise ? rFilterShape[2] : rFilterShape[heightIndex]; + unsigned int filterWidth = depthwise ? rFilterShape[3] : rFilterShape[widthIndex]; + + for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++) + { + for (unsigned int cOutput = 0; cOutput < outputChannels; cOutput++) + { + for (unsigned int yOutput = 0; yOutput < outputHeight; yOutput++) + { + for (unsigned int xOutput = 0; xOutput < outputWidth; xOutput++) + { + // This loop goes over each output element. + float sum = 0.0f; + + // For depthwise, each output channel corresponds to exactly one input channel. + // For normal, must loop over each input channel. + for (unsigned int cInput = 0; cInput < (depthwise ? 1 : inputChannels); cInput++) + { + unsigned int depthwiseMultiplierIdx = 0; + if (depthwise) + { + cInput = cOutput / depthMultiplier; + depthwiseMultiplierIdx = cOutput % depthMultiplier; + } + + for (unsigned int yFilter = 0; yFilter < filterHeight; yFilter++) + { + for (unsigned int xFilter = 0; xFilter < filterWidth; xFilter++) + { + // This loop goes over each input element for each output element. + unsigned int filterIndex = 0; + + // Since dimensionality of kernel depends on depthwiseness, so does index. + if (depthwise) + { + filterIndex = depthwiseMultiplierIdx * filterWidth * filterHeight * inputChannels + + cInput * filterWidth * filterHeight + + yFilter * filterWidth + + xFilter; + } + else + { + if (dataLayout == DataLayout::NHWC) + { + filterIndex = cOutput * filterHeight * filterWidth * inputChannels + + yFilter * filterWidth * inputChannels + + xFilter * inputChannels + + cInput; + } + else + { + filterIndex = cOutput * filterWidth * filterHeight * inputChannels + + cInput * filterWidth * filterHeight + + yFilter * filterWidth + + xFilter; + } + } + rFilterDecoder += filterIndex; + float filterValue = rFilterDecoder.Get(); + rFilterDecoder -= filterIndex; + + unsigned int yInput = yOutput * yStride + yFilter * yDilation; + unsigned int xInput = xOutput * xStride + xFilter * xDilation; + + float inputValue; + + // Check if we're in the padding. + if (yInput < paddingTop || yInput >= inputHeight + paddingTop || + xInput < paddingLeft || xInput >= inputWidth + paddingLeft ) + { + inputValue = 0.0f; + } + else + { + unsigned int inputIndex; + + if (dataLayout == DataLayout::NHWC) + { + inputIndex = batchIdx * inputHeight * inputWidth * inputChannels + + (yInput - paddingTop) * inputWidth * inputChannels + + (xInput - paddingLeft) * inputChannels + + cInput; + } + else + { + inputIndex = batchIdx * inputWidth * inputHeight * inputChannels + + inputWidth * inputHeight * cInput + + inputWidth * (yInput - paddingTop) + + xInput - paddingLeft; + } + rInputDecoder += inputIndex; + inputValue = rInputDecoder.Get(); + rInputDecoder -= inputIndex; + } + sum += filterValue * inputValue; + } + } + } + + if (biasEnabled) + { + *pBiasDecoder += cOutput; + sum += pBiasDecoder->Get(); + *pBiasDecoder -= cOutput; + } + unsigned int outIdx = GetOffset(dataLayout, rOutputShape, batchIdx, cOutput, yOutput, xOutput); + + rOutputEncoder += outIdx; + rOutputEncoder.Set(sum); + rOutputEncoder -= outIdx; + } + } + } + } +} + } //namespace armnn -- cgit v1.2.1