aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--1.0/HalPolicy.cpp18
1 files changed, 9 insertions, 9 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index 719d1a24..d66f483d 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -471,7 +471,6 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
// ArmNN does not currently support non-fixed weights or bias
// Find the shape of the weights tensor. In AndroidNN this will be [ 1, H, W, I * M ]
- // which is equal to [ M, H, W, I ]
const Operand* weightsOperand = GetInputOperand(operation, 1, model);
if (weightsOperand == nullptr)
@@ -480,18 +479,19 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
}
// Reinterpret weight data as [ H, W, I, M ]
- armnn::TensorShape weightsShape({ weightsOperand->dimensions[1], weightsOperand->dimensions[2],
+ armnn::TensorShape weightsShape({ weightsOperand->dimensions[1],
+ weightsOperand->dimensions[2],
inputInfo.GetShape()[3],
weightsOperand->dimensions[3] / inputInfo.GetShape()[3] });
- // Swizzle weight data [ H, W, I, M ] -> [ M, H, W, I ]
- const armnn::PermutationVector HWIMToMHWI = { 1U, 2U, 3U, 0U };
+ // Swizzle weight data [ H, W, I, M ] -> [ M, I, H, W ]
+ const armnn::PermutationVector HWIMToMIHW = { 2U, 3U, 1U, 0U };
- ConstTensorPin weightsPin =
- ConvertOperationInputToConstTensorPin(operation, 1, model, data, HWIMToMHWI, &weightsShape);
+ const ConstTensorPin weightsPin = ConvertOperationInputToConstTensorPin(operation, 1, model, data,
+ HWIMToMIHW, &weightsShape);
// Bias is a 1D tensor
- ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin(operation, 2, model, data);
+ const ConstTensorPin biasPin = ConvertOperationInputToConstTensorPin(operation, 2, model, data);
if (!weightsPin.IsValid() || !biasPin.IsValid())
{
@@ -530,8 +530,8 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
return Fail("%s: Operation has invalid inputs", __func__);
}
- const uint32_t kernelX = weights.GetShape()[2];
- const uint32_t kernelY = weights.GetShape()[1];
+ const uint32_t kernelX = weights.GetShape()[3];
+ const uint32_t kernelY = weights.GetShape()[2];
const uint32_t inputX = inputInfo.GetShape()[2];
const uint32_t inputY = inputInfo.GetShape()[1];