aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-12-18 09:26:39 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-04 17:28:07 +0000
commit747ef82c88f9afe14a8b80b6b3b34118353e97f2 (patch)
treea29ac33b84fb96a41103a0a97327189495374cc9 /src/armnnUtils
parent760892724d131c7da4b9baad05cddd49276ad6bb (diff)
downloadarmnn-747ef82c88f9afe14a8b80b6b3b34118353e97f2.tar.gz
MLCE-77 Depthwise Convolution with depth multiplier > 1 doesn't work
* Unified ArmNN's weight format to [ M, I, H, W ] for the depthwise convolution * Added conversion utilities to permute/reshape the weights as appropriate when using CL and Neon backends * Updated the reference implementation of the convolution * Updated the relevant unit tests accordingly !android-nn-driver:459 Change-Id: I07d0818efa9d1ca1e5dad82983aac1fe78eadb18
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/ParserPrototxtFixture.hpp2
-rw-r--r--src/armnnUtils/Permute.cpp57
-rw-r--r--src/armnnUtils/Permute.hpp5
3 files changed, 61 insertions, 3 deletions
diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp
index fa21aba479..acb8f82c4d 100644
--- a/src/armnnUtils/ParserPrototxtFixture.hpp
+++ b/src/armnnUtils/ParserPrototxtFixture.hpp
@@ -14,8 +14,6 @@
#include <Network.hpp>
#include <VerificationHelpers.hpp>
-#include <backendsCommon/BackendRegistry.hpp>
-
#include <boost/format.hpp>
#include <string>
diff --git a/src/armnnUtils/Permute.cpp b/src/armnnUtils/Permute.cpp
index 61f4e0e644..6deff90168 100644
--- a/src/armnnUtils/Permute.cpp
+++ b/src/armnnUtils/Permute.cpp
@@ -9,6 +9,7 @@
#include <armnn/Tensor.hpp>
#include <cassert>
+#include <cstring>
namespace
{
@@ -46,10 +47,29 @@ public:
Unroll(0, srcData, dstData, srcEnd, dstEnd);
}
+ void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
+ {
+ assert(srcData);
+ assert(dstData);
+ assert(dataTypeSize > 0);
+
+ const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
+ unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
+
+ const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
+ unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
+
+ Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
+ }
+
private:
template <typename T>
void Unroll(size_type dimension, const T* srcData, T* dstData, const T* srcEnd, T* dstEnd)
{
+ assert(srcData);
+ assert(dstData);
+ assert(srcEnd);
+ assert(dstEnd);
assert(srcData < srcEnd);
assert(dstData < dstEnd);
@@ -69,6 +89,35 @@ private:
}
}
+ void Unroll(size_type dimension,
+ const unsigned char* srcData, unsigned char* dstData,
+ const unsigned char* srcEnd, unsigned char* dstEnd,
+ size_t dataTypeSize)
+ {
+ assert(srcData);
+ assert(dstData);
+ assert(srcEnd);
+ assert(dstEnd);
+ assert(srcData < srcEnd);
+ assert(dstData < dstEnd);
+ assert(dataTypeSize > 0);
+
+ if (dimension >= m_DstShape.GetNumDimensions())
+ {
+ ::memcpy(dstData, srcData, dataTypeSize);
+ }
+ else
+ {
+ for (size_type i = 0; i < m_DstShape[dimension]; i++)
+ {
+ Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
+
+ srcData += m_SrcStrides[dimension] * dataTypeSize;
+ dstData += m_DstStrides[dimension] * dataTypeSize;
+ }
+ }
+ }
+
armnn::TensorShape m_DstShape;
std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
@@ -102,6 +151,12 @@ armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::Permutati
return outInfo;
}
+void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
+ const void* src, void* dst, size_t dataTypeSize)
+{
+ PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
+}
+
template <typename T>
void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst)
{
@@ -117,5 +172,7 @@ template void Permute(const armnn::TensorShape& dstShape, const armnn::Permutati
const uint8_t* src, uint8_t* dst);
template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
const int32_t* src, int32_t* dst);
+template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
+ const bool* src, bool* dst);
} // namespace armnnUtils
diff --git a/src/armnnUtils/Permute.hpp b/src/armnnUtils/Permute.hpp
index 700ddc72ce..4e4319822b 100644
--- a/src/armnnUtils/Permute.hpp
+++ b/src/armnnUtils/Permute.hpp
@@ -14,7 +14,10 @@ armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::Per
armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings);
+void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
+ const void* src, void* dst, size_t dataTypeSize);
+
template <typename T>
void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst);
-} // namespace armnnUtils \ No newline at end of file
+} // namespace armnnUtils