aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ConvImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/ConvImpl.hpp')
-rw-r--r--src/backends/reference/workloads/ConvImpl.hpp75
1 files changed, 58 insertions, 17 deletions
diff --git a/src/backends/reference/workloads/ConvImpl.hpp b/src/backends/reference/workloads/ConvImpl.hpp
index 60a3622c55..4b15c1da6d 100644
--- a/src/backends/reference/workloads/ConvImpl.hpp
+++ b/src/backends/reference/workloads/ConvImpl.hpp
@@ -6,6 +6,7 @@
#pragma once
#include "RefWorkloadUtils.hpp"
+#include "TensorBufferArrayView.hpp"
#include <armnn/Tensor.hpp>
@@ -66,6 +67,10 @@ static void ConvImpl(ConvData data,
const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]);
+ TensorBufferArrayView<InputType> output(outputInfo0.GetShape(),
+ GetOutputTensorData<InputType>(0, data),
+ data.m_Parameters.m_DataLayout);
+
const DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout);
const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
@@ -123,18 +128,41 @@ static void ConvImpl(ConvData data,
// Since dimensionality of kernel depends on depthwiseness, so does index.
if (depthwise)
{
- filterIndex = depthwiseMultiplierIdx * widthFilter * heightFilter * channelsInput +
- cInput * widthFilter * heightFilter +
- yFilter * widthFilter +
- xFilter;
+ if (data.m_Parameters.m_DataLayout == DataLayout::NHWC)
+ {
+ filterIndex = depthwiseMultiplierIdx * heightFilter * widthFilter
+ * channelsInput +
+ yFilter * widthFilter * channelsInput +
+ xFilter * channelsInput +
+ cInput;
+ }
+ else
+ {
+ filterIndex = depthwiseMultiplierIdx * widthFilter * heightFilter
+ * channelsInput +
+ cInput * widthFilter * heightFilter +
+ yFilter * widthFilter +
+ xFilter;
+ }
}
else
{
- filterIndex = cOutput * widthFilter * heightFilter * channelsInput +
- cInput * widthFilter * heightFilter +
- yFilter * widthFilter +
- xFilter;
+ if (data.m_Parameters.m_DataLayout == DataLayout::NHWC)
+ {
+ filterIndex = cOutput * heightFilter * widthFilter * channelsInput +
+ yFilter * widthFilter * channelsInput +
+ xFilter * channelsInput +
+ cInput;
+ }
+ else
+ {
+ filterIndex = cOutput * widthFilter * heightFilter * channelsInput +
+ cInput * widthFilter * heightFilter +
+ yFilter * widthFilter +
+ xFilter;
+ }
}
+
AccumulatorType filterValue = filterData[filterIndex] -
boost::numeric_cast<AccumulatorType>(filterOffset);
@@ -151,11 +179,27 @@ static void ConvImpl(ConvData data,
}
else
{
- inputValue = inputData[batchIdx * widthInput * heightInput * channelsInput +
- widthInput * heightInput * cInput +
- widthInput * (yInput - paddingTop) +
- xInput - paddingLeft] -
- boost::numeric_cast<AccumulatorType>(inputOffset);
+ unsigned int inputIndex;
+
+ if (data.m_Parameters.m_DataLayout == DataLayout::NHWC)
+ {
+ inputIndex = batchIdx * heightInput * widthInput * channelsInput +
+ (yInput - paddingTop) * widthInput * channelsInput +
+ (xInput - paddingLeft) * channelsInput +
+ cInput;
+
+ }
+ else
+ {
+ inputIndex = batchIdx * widthInput * heightInput * channelsInput +
+ widthInput * heightInput * cInput +
+ widthInput * (yInput - paddingTop) +
+ xInput - paddingLeft;
+ }
+
+ inputValue = inputData[inputIndex] -
+ boost::numeric_cast<AccumulatorType>(inputOffset);
+
}
sum += filterValue * inputValue;
}
@@ -179,10 +223,7 @@ static void ConvImpl(ConvData data,
sum = std::min<AccumulatorType>(std::max<AccumulatorType>(sum, 0), 255);
}
- outputData[batchIdx * widthOutput * heightOutput * channelsOutput +
- widthOutput * heightOutput * cOutput +
- widthOutput * yOutput +
- xOutput] = boost::numeric_cast<InputType>(sum);
+ output.Get(batchIdx, cOutput, yOutput, xOutput) = boost::numeric_cast<InputType>(sum);
}
}
}