aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-12-18 09:26:39 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-01-04 17:28:07 +0000
commit747ef82c88f9afe14a8b80b6b3b34118353e97f2 (patch)
treea29ac33b84fb96a41103a0a97327189495374cc9 /src/armnnTfParser/TfParser.cpp
parent760892724d131c7da4b9baad05cddd49276ad6bb (diff)
downloadarmnn-747ef82c88f9afe14a8b80b6b3b34118353e97f2.tar.gz
MLCE-77 Depthwise Convolution with depth multiplier > 1 doesn't work
* Unified ArmNN's weight format to [ M, I, H, W ] for the depthwise convolution * Added conversion utilities to permute/reshape the weights as appropriate when using CL and Neon backends * Updated the reference implementation of the convolution * Updated the relevant unit tests accordingly !android-nn-driver:459 Change-Id: I07d0818efa9d1ca1e5dad82983aac1fe78eadb18
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp16
1 files changed, 6 insertions, 10 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 7f04757b75..7a213c0909 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1338,13 +1338,9 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
uint32_t inputWidth = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()];
// Mappings from TensorFlow filter tensors to the ArmNN filter tensors.
- // Tensorflow weights are [H, W, In, Out].
- // ArmNN weights have to be [Out, H, W, In] when the data layout is NHWC,
- // and [Out, In, H, W] when the data layout is NCHW.
- PermutationVector permutationVector =
- dataLayout == DataLayout::NHWC ?
- std::initializer_list<unsigned int>{ 1, 2, 3, 0 } : // NHWC: [H, W, In, Out] -> [Out, H, W, In]
- std::initializer_list<unsigned int>{ 2, 3, 1, 0 }; // NCHW: [H, W, In, Out] -> [Out, In, H, W]
+ // Tensorflow weights come in the format [H, W, I, M].
+ // ArmNN weights have to be [M, I, H, W].
+ PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
// Swizzle the tensor using the given permutation vector.
const TensorInfo& weightTensorInfo = weightNode->GetTensorInfo();
@@ -1358,8 +1354,8 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
// Create a weight tensor with the newly swizzled data.
ConstTensor weightTensor(weightTensorSwizzledInfo, weightTensorSwizzledData);
- uint32_t weightHeight = weightTensor.GetShape()[dataLayoutIndexed.GetHeightIndex()];
- uint32_t weightWidth = weightTensor.GetShape()[dataLayoutIndexed.GetWidthIndex()];
+ uint32_t weightHeight = weightTensor.GetShape()[2];
+ uint32_t weightWidth = weightTensor.GetShape()[3];
bool padding = false;
TensorInfo outputInfo;
@@ -1393,7 +1389,7 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
outputHeight,
outputWidth,
- weightTensor.GetShape()[0] * weightTensor.GetShape()[3]},
+ weightTensor.GetShape()[0] * weightTensor.GetShape()[1]},
DataType::Float32);
break;
case DataLayout::NCHW: