aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ConvImpl.hpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-11-28 16:22:36 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2018-12-05 13:48:43 +0000
commit4631582a4c8b92917633d0af4ebcc8fff2abd04a (patch)
tree15f237e78af14394486699dedb834531af207067 /src/backends/reference/workloads/ConvImpl.hpp
parentc2130a070e6a9196d193c93a02b5f118810dd59a (diff)
downloadarmnn-4631582a4c8b92917633d0af4ebcc8fff2abd04a.tar.gz
IVGCVSW-2264 Remove input swizzling from ParseConv2D in the TF parser
* Removed the input swizzling when the data layout is NHWC * Permuting weights depending on the data layout used * Added getter methods to ParsedConstTfOperation to get the tensor info and the storage memory area, needed for swizzling the weights * Added unit tests for both NHWC and NCHW data layouts Change-Id: I6543900c594417df630b2663d8551158b93b7836
Diffstat (limited to 'src/backends/reference/workloads/ConvImpl.hpp')
-rw-r--r--src/backends/reference/workloads/ConvImpl.hpp11
1 files changed, 7 insertions, 4 deletions
diff --git a/src/backends/reference/workloads/ConvImpl.hpp b/src/backends/reference/workloads/ConvImpl.hpp
index b8e2deaa9c..704bc368d2 100644
--- a/src/backends/reference/workloads/ConvImpl.hpp
+++ b/src/backends/reference/workloads/ConvImpl.hpp
@@ -15,6 +15,8 @@
#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <DataLayoutIndexed.hpp>
+
#include <cmath>
#include <limits>
@@ -74,6 +76,7 @@ static void ConvImpl(ConvData data,
data.m_Parameters.m_DataLayout);
const armnnUtils::DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout);
+
const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
const unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
const unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
@@ -91,10 +94,10 @@ static void ConvImpl(ConvData data,
unsigned int heightFilter = filterInfo.GetShape()[heightIndex];
unsigned int widthFilter = filterInfo.GetShape()[widthIndex];
- unsigned int paddingTop = data.m_Parameters.m_PadTop;
+ unsigned int paddingTop = data.m_Parameters.m_PadTop;
unsigned int paddingLeft = data.m_Parameters.m_PadLeft;
- unsigned int hStride = data.m_Parameters.m_StrideY;
- unsigned int xStride = data.m_Parameters.m_StrideX;
+ unsigned int xStride = data.m_Parameters.m_StrideX;
+ unsigned int yStride = data.m_Parameters.m_StrideY;
// The world's least efficient convolution.
for (unsigned int batchIdx = 0; batchIdx < batchSize; batchIdx++)
@@ -168,7 +171,7 @@ static void ConvImpl(ConvData data,
AccumulatorType filterValue = filterData[filterIndex] -
boost::numeric_cast<AccumulatorType>(filterOffset);
- unsigned int yInput = yOutput * hStride + yFilter;
+ unsigned int yInput = yOutput * yStride + yFilter;
unsigned int xInput = xOutput * xStride + xFilter;
AccumulatorType inputValue;