aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadUtils.cpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-12-18 09:26:39 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-04 17:28:07 +0000
commit747ef82c88f9afe14a8b80b6b3b34118353e97f2 (patch)
treea29ac33b84fb96a41103a0a97327189495374cc9 /src/backends/backendsCommon/WorkloadUtils.cpp
parent760892724d131c7da4b9baad05cddd49276ad6bb (diff)
downloadarmnn-747ef82c88f9afe14a8b80b6b3b34118353e97f2.tar.gz
MLCE-77 Depthwise Convolution with depth multiplier > 1 doesn't work
* Unified ArmNN's weight format to [ M, I, H, W ] for the depthwise convolution * Added conversion utilities to permute/reshape the weights as appropriate when using CL and Neon backends * Updated the reference implementation of the convolution * Updated the relevant unit tests accordingly !android-nn-driver:459 Change-Id: I07d0818efa9d1ca1e5dad82983aac1fe78eadb18
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadUtils.cpp111
1 files changed, 111 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
new file mode 100644
index 0000000000..fa387a7a0b
--- /dev/null
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -0,0 +1,111 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "WorkloadUtils.hpp"
+
+namespace armnn
+{
+
+armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor,
+ const PermutationVector& permutationVector,
+ void* permuteBuffer)
+{
+ BOOST_ASSERT_MSG(tensor, "Invalid input tensor");
+ BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
+
+ TensorInfo tensorInfo = tensor->GetTensorInfo();
+
+ if (permutationVector.GetSize() > 0)
+ {
+ tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector);
+ armnnUtils::Permute(tensorInfo.GetShape(), permutationVector,
+ tensor->GetConstTensor<void>(), permuteBuffer,
+ GetDataTypeSize(tensorInfo.GetDataType()));
+ }
+ else
+ {
+ ::memcpy(permuteBuffer, tensor->GetConstTensor<void>(), tensorInfo.GetNumBytes());
+ }
+
+ return ConstTensor(tensorInfo, permuteBuffer);
+}
+
+void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout)
+{
+ // Reshape the weights in-place
+ const TensorShape& weightShape = weightInfo.GetShape();
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ // The data layout is NHWC, reshape from [ H, W, I, M ] to [ 1, H, W, I * M ]
+ weightInfo.SetShape({ 1,
+ weightShape[0],
+ weightShape[1],
+ weightShape[2] * weightShape[3] });
+ break;
+ case DataLayout::NCHW:
+ default:
+ // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ]
+ weightInfo.SetShape({ 1,
+ weightShape[0] * weightShape[1],
+ weightShape[2],
+ weightShape[3] });
+ break;
+ }
+}
+
+TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout)
+{
+ // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
+ // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
+
+ // 1. Permute the weights if necessary
+ // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done
+ // starting from the current shape of [ M, I, H, W ]
+ TensorInfo weightPermutedInfo(weightInfo);
+ if (dataLayout == DataLayout::NHWC)
+ {
+ // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ]
+ PermutationVector permutationVector{ 3, 2, 0, 1 };
+ weightPermutedInfo = armnnUtils::Permuted(weightInfo, permutationVector);
+ }
+
+ // 2. Reshape the weights
+ ReshapeWeightsForAcl(weightPermutedInfo, dataLayout);
+
+ // 3. Return the permuted weight info
+ return weightPermutedInfo;
+}
+
+armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* weightTensor,
+ DataLayout dataLayout,
+ void* permuteBuffer)
+{
+ BOOST_ASSERT_MSG(weightTensor, "Invalid input tensor");
+ BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
+
+ // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
+ // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
+
+ // 1. Permute the weights if necessary
+ // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done
+ // starting from the current shape of [ M, I, H, W ]
+ // If no permutation is necessary, leave the permutation vector empty
+ PermutationVector permutationVector{};
+ if (dataLayout == DataLayout::NHWC)
+ {
+ // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ]
+ permutationVector = { 3, 2, 0, 1 };
+ }
+ ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
+
+ // 2. Reshape the weights
+ ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout);
+
+ // 3. Return both the tensor and the allocated storage to ensure that the data stays alive
+ return weightPermuted;
+}
+
+} // namespace armnn