aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-12-18 09:32:02 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-02 16:09:56 +0000
commit361ccc845c46478dc426a3bb67542b3e3a69db11 (patch)
treed15fbd9c505db337605f3fb28ec5d8f0430bd110
parentc1944473091003cc17af8dbbb850a5106132197d (diff)
downloadandroid-nn-driver-361ccc845c46478dc426a3bb67542b3e3a69db11.tar.gz
MLCE-77 Depthwise Convolution with depth multiplier > 1 doesn't work
* Changed the weight swizzling to [ M, I, H, W ] as now required by ArmNN !armnn:460 Change-Id: I7c25e7ab3e1efc47d7db3f2b57e17382ea8b36cf
-rw-r--r--1.0/HalPolicy.cpp18
1 files changed, 9 insertions, 9 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index 719d1a24..d66f483d 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -471,7 +471,6 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
// ArmNN does not currently support non-fixed weights or bias
// Find the shape of the weights tensor. In AndroidNN this will be [ 1, H, W, I * M ]
- // which is equal to [ M, H, W, I ]
const Operand* weightsOperand = GetInputOperand(operation, 1, model);
if (weightsOperand == nullptr)
@@ -480,18 +479,19 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
}
// Reinterpret weight data as [ H, W, I, M ]
- armnn::TensorShape weightsShape({ weightsOperand->dimensions[1], weightsOperand->dimensions[2],
+ armnn::TensorShape weightsShape({ weightsOperand->dimensions[1],
+ weightsOperand->dimensions[2],
inputInfo.GetShape()[3],
weightsOperand->dimensions[3] / inputInfo.GetShape()[3] });
- // Swizzle weight data [ H, W, I, M ] -> [ M, H, W, I ]
- const armnn::PermutationVector HWIMToMHWI = { 1U, 2U, 3U, 0U };
+ // Swizzle weight data [ H, W, I, M ] -> [ M, I, H, W ]
+ const armnn::PermutationVector HWIMToMIHW = { 2U, 3U, 1U, 0U };
- ConstTensorPin weightsPin =
- ConvertOperationInputToConstTensorPin(operation, 1, model, data, HWIMToMHWI, &weightsShape);
+ const ConstTensorPin weightsPin = ConvertOperationInputToConstTensorPin(operation, 1, model, data,
+ HWIMToMIHW, &weightsShape);
// Bias is a 1D tensor
- ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin(operation, 2, model, data);
+ const ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin(operation, 2, model, data);
if (!weightsPin.IsValid() || !biasPin.IsValid())
{
@@ -530,8 +530,8 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
return Fail("%s: Operation has invalid inputs", __func__);
}
- const uint32_t kernelX = weights.GetShape()[2];
- const uint32_t kernelY = weights.GetShape()[1];
+ const uint32_t kernelX = weights.GetShape()[3];
+ const uint32_t kernelY = weights.GetShape()[2];
const uint32_t inputX = inputInfo.GetShape()[2];
const uint32_t inputY = inputInfo.GetShape()[1];