diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2018-12-18 09:32:02 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-01-02 16:09:56 +0000 |
commit | 361ccc845c46478dc426a3bb67542b3e3a69db11 (patch) | |
tree | d15fbd9c505db337605f3fb28ec5d8f0430bd110 | |
parent | c1944473091003cc17af8dbbb850a5106132197d (diff) | |
download | android-nn-driver-361ccc845c46478dc426a3bb67542b3e3a69db11.tar.gz |
MLCE-77 Depthwise Convolution with depth multiplier > 1 doesn't work
* Changed the weight swizzling to [ M, I, H, W ] as now required by ArmNN
!armnn:460
Change-Id: I7c25e7ab3e1efc47d7db3f2b57e17382ea8b36cf
-rw-r--r-- | 1.0/HalPolicy.cpp | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp index 719d1a24..d66f483d 100644 --- a/1.0/HalPolicy.cpp +++ b/1.0/HalPolicy.cpp @@ -471,7 +471,6 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model& // ArmNN does not currently support non-fixed weights or bias // Find the shape of the weights tensor. In AndroidNN this will be [ 1, H, W, I * M ] - // which is equal to [ M, H, W, I ] const Operand* weightsOperand = GetInputOperand(operation, 1, model); if (weightsOperand == nullptr) @@ -480,18 +479,19 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model& } // Reinterpret weight data as [ H, W, I, M ] - armnn::TensorShape weightsShape({ weightsOperand->dimensions[1], weightsOperand->dimensions[2], + armnn::TensorShape weightsShape({ weightsOperand->dimensions[1], + weightsOperand->dimensions[2], inputInfo.GetShape()[3], weightsOperand->dimensions[3] / inputInfo.GetShape()[3] }); - // Swizzle weight data [ H, W, I, M ] -> [ M, H, W, I ] - const armnn::PermutationVector HWIMToMHWI = { 1U, 2U, 3U, 0U }; + // Swizzle weight data [ H, W, I, M ] -> [ M, I, H, W ] + const armnn::PermutationVector HWIMToMIHW = { 2U, 3U, 1U, 0U }; - ConstTensorPin weightsPin = - ConvertOperationInputToConstTensorPin(operation, 1, model, data, HWIMToMHWI, &weightsShape); + const ConstTensorPin weightsPin = ConvertOperationInputToConstTensorPin(operation, 1, model, data, + HWIMToMIHW, &weightsShape); // Bias is a 1D tensor - ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin(operation, 2, model, data); + const ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin(operation, 2, model, data); if (!weightsPin.IsValid() || !biasPin.IsValid()) { @@ -530,8 +530,8 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model& return Fail("%s: Operation has invalid inputs", __func__); } - const uint32_t kernelX = weights.GetShape()[2]; - const uint32_t kernelY = weights.GetShape()[1]; + const uint32_t kernelX = weights.GetShape()[3]; + const uint32_t kernelY = weights.GetShape()[2]; const uint32_t inputX = inputInfo.GetShape()[2]; const uint32_t inputY = inputInfo.GetShape()[1]; |