aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp16
1 files changed, 6 insertions, 10 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 7f04757b75..7a213c0909 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1338,13 +1338,9 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
uint32_t inputWidth = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()];
// Mappings from TensorFlow filter tensors to the ArmNN filter tensors.
- // Tensorflow weights are [H, W, In, Out].
- // ArmNN weights have to be [Out, H, W, In] when the data layout is NHWC,
- // and [Out, In, H, W] when the data layout is NCHW.
- PermutationVector permutationVector =
- dataLayout == DataLayout::NHWC ?
- std::initializer_list<unsigned int>{ 1, 2, 3, 0 } : // NHWC: [H, W, In, Out] -> [Out, H, W, In]
- std::initializer_list<unsigned int>{ 2, 3, 1, 0 }; // NCHW: [H, W, In, Out] -> [Out, In, H, W]
+ // Tensorflow weights come in the format [H, W, I, M].
+ // ArmNN weights have to be [M, I, H, W].
+ PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
// Swizzle the tensor using the given permutation vector.
const TensorInfo& weightTensorInfo = weightNode->GetTensorInfo();
@@ -1358,8 +1354,8 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
// Create a weight tensor with the newly swizzled data.
ConstTensor weightTensor(weightTensorSwizzledInfo, weightTensorSwizzledData);
- uint32_t weightHeight = weightTensor.GetShape()[dataLayoutIndexed.GetHeightIndex()];
- uint32_t weightWidth = weightTensor.GetShape()[dataLayoutIndexed.GetWidthIndex()];
+ uint32_t weightHeight = weightTensor.GetShape()[2];
+ uint32_t weightWidth = weightTensor.GetShape()[3];
bool padding = false;
TensorInfo outputInfo;
@@ -1393,7 +1389,7 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
outputHeight,
outputWidth,
- weightTensor.GetShape()[0] * weightTensor.GetShape()[3]},
+ weightTensor.GetShape()[0] * weightTensor.GetShape()[1]},
DataType::Float32);
break;
case DataLayout::NCHW: