diff options
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/ConvImpl.hpp | 75 |
1 files changed, 58 insertions, 17 deletions
diff --git a/src/backends/reference/workloads/ConvImpl.hpp b/src/backends/reference/workloads/ConvImpl.hpp index 60a3622c55..4b15c1da6d 100644 --- a/src/backends/reference/workloads/ConvImpl.hpp +++ b/src/backends/reference/workloads/ConvImpl.hpp @@ -6,6 +6,7 @@ #pragma once #include "RefWorkloadUtils.hpp" +#include "TensorBufferArrayView.hpp" #include <armnn/Tensor.hpp> @@ -66,6 +67,10 @@ static void ConvImpl(ConvData data, const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]); + TensorBufferArrayView<InputType> output(outputInfo0.GetShape(), + GetOutputTensorData<InputType>(0, data), + data.m_Parameters.m_DataLayout); + const DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout); const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex(); const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex(); @@ -123,18 +128,41 @@ static void ConvImpl(ConvData data, // Since dimensionality of kernel depends on depthwiseness, so does index. if (depthwise) { - filterIndex = depthwiseMultiplierIdx * widthFilter * heightFilter * channelsInput + - cInput * widthFilter * heightFilter + - yFilter * widthFilter + - xFilter; + if (data.m_Parameters.m_DataLayout == DataLayout::NHWC) + { + filterIndex = depthwiseMultiplierIdx * heightFilter * widthFilter + * channelsInput + + yFilter * widthFilter * channelsInput + + xFilter * channelsInput + + cInput; + } + else + { + filterIndex = depthwiseMultiplierIdx * widthFilter * heightFilter + * channelsInput + + cInput * widthFilter * heightFilter + + yFilter * widthFilter + + xFilter; + } } else { - filterIndex = cOutput * widthFilter * heightFilter * channelsInput + - cInput * widthFilter * heightFilter + - yFilter * widthFilter + - xFilter; + if (data.m_Parameters.m_DataLayout == DataLayout::NHWC) + { + filterIndex = cOutput * heightFilter * widthFilter * channelsInput + + yFilter * widthFilter * channelsInput + + xFilter * channelsInput + + cInput; + } + else + { + filterIndex = cOutput * widthFilter * heightFilter * channelsInput + + cInput * widthFilter * heightFilter + + yFilter * widthFilter + + xFilter; + } } + AccumulatorType filterValue = filterData[filterIndex] - boost::numeric_cast<AccumulatorType>(filterOffset); @@ -151,11 +179,27 @@ static void ConvImpl(ConvData data, } else { - inputValue = inputData[batchIdx * widthInput * heightInput * channelsInput + - widthInput * heightInput * cInput + - widthInput * (yInput - paddingTop) + - xInput - paddingLeft] - - boost::numeric_cast<AccumulatorType>(inputOffset); + unsigned int inputIndex; + + if (data.m_Parameters.m_DataLayout == DataLayout::NHWC) + { + inputIndex = batchIdx * heightInput * widthInput * channelsInput + + (yInput - paddingTop) * widthInput * channelsInput + + (xInput - paddingLeft) * channelsInput + + cInput; + + } + else + { + inputIndex = batchIdx * widthInput * heightInput * channelsInput + + widthInput * heightInput * cInput + + widthInput * (yInput - paddingTop) + + xInput - paddingLeft; + } + + inputValue = inputData[inputIndex] - + boost::numeric_cast<AccumulatorType>(inputOffset); + } sum += filterValue * inputValue; } @@ -179,10 +223,7 @@ static void ConvImpl(ConvData data, sum = std::min<AccumulatorType>(std::max<AccumulatorType>(sum, 0), 255); } - outputData[batchIdx * widthOutput * heightOutput * channelsOutput + - widthOutput * heightOutput * cOutput + - widthOutput * yOutput + - xOutput] = boost::numeric_cast<InputType>(sum); + output.Get(batchIdx, cOutput, yOutput, xOutput) = boost::numeric_cast<InputType>(sum); } } } |