From d305e1a203077bdbf2e3955abd252904127675a4 Mon Sep 17 00:00:00 2001 From: Keith Davis Date: Wed, 22 Jan 2020 11:57:54 +0000 Subject: IVGCVSW-4335 Add support for per-channel QSymm8 to TfLite parser Signed-off-by: Keith Davis Change-Id: I52f777f56138a27655a821aff376ecd0d3d23511 --- src/armnnTfLiteParser/TfLiteParser.cpp | 88 ++++++++++++++++++++++++++-------- src/armnnTfLiteParser/TfLiteParser.hpp | 2 + 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 17c0781740..d3eed9cfb1 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -336,40 +336,70 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std:: location.AsString())); } } + std::vector safeShape = shapes; + if (safeShape.size() == 0) + { + safeShape.push_back(1); + } float quantizationScale = 0.0f; int32_t quantizationOffset = 0; if (tensorPtr->quantization.get()) { - CHECK_VALID_SIZE(tensorPtr->quantization->scale.size(), 0, 1); - CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1); - - if (tensorPtr->quantization->scale.size() == 1) + if (tensorPtr->quantization->scale.size() <= 1) { - quantizationScale = tensorPtr->quantization->scale[0]; + CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1); + CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1); + + if (tensorPtr->quantization->scale.size() == 1) + { + quantizationScale = tensorPtr->quantization->scale[0]; + } + if (tensorPtr->quantization->zero_point.size() == 1) + { + // NOTE: we lose precision here when converting from 64 bit to 32 + // but this is what we support at the monent in ArmNN + quantizationOffset = boost::numeric_cast(tensorPtr->quantization->zero_point[0]); + } + + armnn::TensorInfo result(boost::numeric_cast(safeShape.size()), + safeShape.data(), + type, + quantizationScale, + quantizationOffset); + + return result; } - if (tensorPtr->quantization->zero_point.size() == 1) + else { - // NOTE: we lose precision here when converting from 64 bit to 32 - // but this is what we support at the monent in ArmNN - quantizationOffset = static_cast(tensorPtr->quantization->zero_point[0]); + std::vector quantizationScales; + std::vector quantizationOffsets; + + // Scale + std::copy(tensorPtr->quantization->scale.begin(), + tensorPtr->quantization->scale.end(), + std::back_inserter(quantizationScales)); + + // QSymm Per-axis + armnn::TensorInfo result(boost::numeric_cast(safeShape.size()), + safeShape.data(), + type, + quantizationScales, + boost::numeric_cast(tensorPtr->quantization->quantized_dimension)); + + return result; } } - - std::vector safeShape = shapes; - if (safeShape.size() == 0) + else { - safeShape.push_back(1); + armnn::TensorInfo result(boost::numeric_cast(safeShape.size()), + safeShape.data(), + type, + quantizationScale, + quantizationOffset); + return result; } - - // two statements (on purpose) for easier debugging: - armnn::TensorInfo result(static_cast(safeShape.size()), - safeShape.data(), - type, - quantizationScale, - quantizationOffset); - return result; } armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr) @@ -2848,6 +2878,11 @@ TfLiteParser::CreateConstTensor(TensorRawPtr tensorPtr, tensorPtr, tensorInfo, permutationVector); + case armnn::DataType::QSymmS8: + return CreateConstTensorAndStoreData(bufferPtr, + tensorPtr, + tensorInfo, + permutationVector); case armnn::DataType::Signed32: return CreateConstTensorAndStoreData(bufferPtr, tensorPtr, @@ -2977,6 +3012,7 @@ void ITfLiteParser::Destroy(ITfLiteParser* parser) TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr && data) : m_FloatData(std::move(data)) , m_Uint8Data(nullptr) +, m_Int8Data(nullptr) , m_Int32Data(nullptr) { } @@ -2984,6 +3020,15 @@ TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr && data) : m_FloatData(nullptr) , m_Uint8Data(std::move(data)) +, m_Int8Data(nullptr) +, m_Int32Data(nullptr) +{ +} + +TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr && data) +: m_FloatData(nullptr) +, m_Uint8Data(nullptr) +, m_Int8Data(std::move(data)) , m_Int32Data(nullptr) { } @@ -2991,6 +3036,7 @@ TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr && data) : m_FloatData(nullptr) , m_Uint8Data(nullptr) +, m_Int8Data(nullptr) , m_Int32Data(std::move(data)) { } diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 42ea1a0372..a34e35fad1 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -166,12 +166,14 @@ private: // Convenience constructors SupportedDataStorage(std::unique_ptr&& data); SupportedDataStorage(std::unique_ptr&& data); + SupportedDataStorage(std::unique_ptr&& data); SupportedDataStorage(std::unique_ptr&& data); private: // Pointers to the data buffers std::unique_ptr m_FloatData; std::unique_ptr m_Uint8Data; + std::unique_ptr m_Int8Data; std::unique_ptr m_Int32Data; }; -- cgit v1.2.1