diff options
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 560cdf1779..593f3eb02d 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -301,7 +301,8 @@ void CalcPadding(uint32_t inputSize, } } -armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes) +armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::vector<unsigned int>& shapes, + const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3}) { armnn::DataType type; CHECK_TENSOR_PTR(tensorPtr); @@ -317,10 +318,12 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std:: case tflite::TensorType_INT8: if (tensorPtr->quantization->zero_point.size() == 1 && tensorPtr->quantization->zero_point[0] != 0) { + // Per-tensor type = armnn::DataType::QAsymmS8; } else { + // Per-channel type = armnn::DataType::QSymmS8; } break; @@ -388,12 +391,13 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std:: tensorPtr->quantization->scale.end(), std::back_inserter(quantizationScales)); - // QSymm Per-axis + // QSymmS8 Per-axis armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()), safeShape.data(), type, quantizationScales, - boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension)); + dimensionMappings[boost::numeric_cast<unsigned int>( + tensorPtr->quantization->quantized_dimension)]); return result; } @@ -409,10 +413,11 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std:: } } -armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr) +armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, + const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3}) { auto const & dimensions = AsUnsignedVector(tensorPtr->shape); - return ToTensorInfo(tensorPtr, dimensions); + return ToTensorInfo(tensorPtr, dimensions, dimensionMappings); } template<typename T> @@ -905,8 +910,11 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor); desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor); + // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W]) + PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W] + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); - armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1]); + armnn::TensorInfo filterTensorInfo = ToTensorInfo(inputs[1], permutationVector); // Assuming input is NHWC unsigned int inputHeight = inputTensorInfo.GetShape()[1]; @@ -922,9 +930,6 @@ void TfLiteParser::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorInd inputTensorInfo.GetShape()[3], filterTensorInfo.GetShape()[3] / inputTensorInfo.GetShape()[3] }); - // Mappings from TensorflowLite filter tensors to the ArmNN filter tensors (ArmNN weights have to be [M, I, H, W]) - PermutationVector permutationVector{ 2, 3, 1, 0 }; // [H, W, I, M] -> [M, I, H, W] - CalcPadding(inputHeight, filterHeight, desc.m_StrideY, desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding); CalcPadding(inputWidth, filterWidth, desc.m_StrideX, |