aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp23
1 files changed, 14 insertions, 9 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 560cdf1779..593f3eb02d 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -301,7 +301,8 @@ void CalcPadding(uint32_t inputSize,
}
}
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes,
+ const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3})
{
armnn::DataType type;
CHECK_TENSOR_PTR(tensorPtr);
@@ -317,10 +318,12 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
case tflite::TensorType_INT8:
if (tensorPtr->quantization->zero_point.size() == 1 && tensorPtr->quantization->zero_point[0] != 0)
{
+ // Per-tensor
type = armnn::DataType::QAsymmS8;
}
else
{
+ // Per-channel
type = armnn::DataType::QSymmS8;
}
break;
@@ -388,12 +391,13 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
tensorPtr->quantization->scale.end(),
std::back_inserter(quantizationScales));
- // QSymm Per-axis
+ // QSymmS8 Per-axis
armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
safeShape.data(),
type,
quantizationScales,
- boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
+ dimensionMappings[boost::numeric_cast<unsigned int>(
+ tensorPtr->quantization->quantized_dimension)]);
return result;
}
@@ -409,10 +413,11 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
}
}
-armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
+armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr,
+ const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3})
{
auto const & dimensions = AsUnsignedVector(tensorPtr->shape);
- return ToTensorInfo(tensorPtr, dimensions);
+ return ToTensorInfo(tensorPtr, dimensions, dimensionMappings);
}
template<typename T>
@@ -905,8 +910,11 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd
desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
+ // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
+ PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
+
armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
- armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]);
+ armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1], permutationVector);
// Assuming input is NHWC
unsigned int inputHeight = inputTensorInfo.GetShape()[1];
@@ -922,9 +930,6 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd
inputTensorInfo.GetShape()[3],
filterTensorInfo.GetShape()[3] / inputTensorInfo.GetShape()[3] });
- // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W])
- PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W]
-
CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
CalcPadding(inputWidth, filterWidth, desc.m_StrideX,