diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2018-12-18 09:26:39 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-04 17:28:07 +0000 |
commit | 747ef82c88f9afe14a8b80b6b3b34118353e97f2 (patch) | |
tree | a29ac33b84fb96a41103a0a97327189495374cc9 /src/armnnUtils | |
parent | 760892724d131c7da4b9baad05cddd49276ad6bb (diff) | |
download | armnn-747ef82c88f9afe14a8b80b6b3b34118353e97f2.tar.gz |
MLCE-77 Depthwise Convolution with depth multiplier > 1 doesn't work
* Unified ArmNN's weight format to [ M, I, H, W ] for the depthwise convolution
* Added conversion utilities to permute/reshape the weights as appropriate
when using CL and Neon backends
* Updated the reference implementation of the convolution
* Updated the relevant unit tests accordingly
!android-nn-driver:459
Change-Id: I07d0818efa9d1ca1e5dad82983aac1fe78eadb18
Diffstat (limited to 'src/armnnUtils')
-rw-r--r-- | src/armnnUtils/ParserPrototxtFixture.hpp | 2 | ||||
-rw-r--r-- | src/armnnUtils/Permute.cpp | 57 | ||||
-rw-r--r-- | src/armnnUtils/Permute.hpp | 5 |
3 files changed, 61 insertions, 3 deletions
diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp index fa21aba479..acb8f82c4d 100644 --- a/src/armnnUtils/ParserPrototxtFixture.hpp +++ b/src/armnnUtils/ParserPrototxtFixture.hpp @@ -14,8 +14,6 @@ #include <Network.hpp> #include <VerificationHelpers.hpp> -#include <backendsCommon/BackendRegistry.hpp> - #include <boost/format.hpp> #include <string> diff --git a/src/armnnUtils/Permute.cpp b/src/armnnUtils/Permute.cpp index 61f4e0e644..6deff90168 100644 --- a/src/armnnUtils/Permute.cpp +++ b/src/armnnUtils/Permute.cpp @@ -9,6 +9,7 @@ #include <armnn/Tensor.hpp> #include <cassert> +#include <cstring> namespace { @@ -46,10 +47,29 @@ public: Unroll(0, srcData, dstData, srcEnd, dstEnd); } + void Unroll(const void* srcData, void* dstData, size_t dataTypeSize) + { + assert(srcData); + assert(dstData); + assert(dataTypeSize > 0); + + const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData); + unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData); + + const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize; + unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize; + + Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize); + } + private: template <typename T> void Unroll(size_type dimension, const T* srcData, T* dstData, const T* srcEnd, T* dstEnd) { + assert(srcData); + assert(dstData); + assert(srcEnd); + assert(dstEnd); assert(srcData < srcEnd); assert(dstData < dstEnd); @@ -69,6 +89,35 @@ private: } } + void Unroll(size_type dimension, + const unsigned char* srcData, unsigned char* dstData, + const unsigned char* srcEnd, unsigned char* dstEnd, + size_t dataTypeSize) + { + assert(srcData); + assert(dstData); + assert(srcEnd); + assert(dstEnd); + assert(srcData < srcEnd); + assert(dstData < dstEnd); + assert(dataTypeSize > 0); + + if (dimension >= m_DstShape.GetNumDimensions()) + { + ::memcpy(dstData, srcData, dataTypeSize); + } + else + { + for (size_type i = 0; i < m_DstShape[dimension]; i++) + { + Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize); + + srcData += m_SrcStrides[dimension] * dataTypeSize; + dstData += m_DstStrides[dimension] * dataTypeSize; + } + } + } + armnn::TensorShape m_DstShape; std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides; std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides; @@ -102,6 +151,12 @@ armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::Permutati return outInfo; } +void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, + const void* src, void* dst, size_t dataTypeSize) +{ + PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize); +} + template <typename T> void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst) { @@ -117,5 +172,7 @@ template void Permute(const armnn::TensorShape& dstShape, const armnn::Permutati const uint8_t* src, uint8_t* dst); template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const int32_t* src, int32_t* dst); +template void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, + const bool* src, bool* dst); } // namespace armnnUtils diff --git a/src/armnnUtils/Permute.hpp b/src/armnnUtils/Permute.hpp index 700ddc72ce..4e4319822b 100644 --- a/src/armnnUtils/Permute.hpp +++ b/src/armnnUtils/Permute.hpp @@ -14,7 +14,10 @@ armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::Per armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings); +void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, + const void* src, void* dst, size_t dataTypeSize); + template <typename T> void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const T* src, T* dst); -} // namespace armnnUtils
\ No newline at end of file +} // namespace armnnUtils |