diff options
author | Keith Davis <keith.davis@arm.com> | 2020-01-22 11:57:54 +0000 |
---|---|---|
committer | Keith Davis Arm <keith.davis@arm.com> | 2020-01-28 17:15:32 +0000 |
commit | d305e1a203077bdbf2e3955abd252904127675a4 (patch) | |
tree | 6c375c3ad2f0b62ca70b0914845d9fd4279f10d6 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | b0efc60fa5740b34f1896a3c3e979f4dfd44fa2e (diff) | |
download | armnn-d305e1a203077bdbf2e3955abd252904127675a4.tar.gz |
IVGCVSW-4335 Add support for per-channel QSymm8 to TfLite parser
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: I52f777f56138a27655a821aff376ecd0d3d23511
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 88 |
1 file changed, 67 insertions, 21 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 17c0781740..d3eed9cfb1 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -336,40 +336,70 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std:: location.AsString())); } } + std::vector<unsigned int> safeShape = shapes; + if (safeShape.size() == 0) + { + safeShape.push_back(1); + } float quantizationScale = 0.0f; int32_t quantizationOffset = 0; if (tensorPtr->quantization.get()) { - CHECK_VALID_SIZE(tensorPtr->quantization->scale.size(), 0, 1); - CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1); - - if (tensorPtr->quantization->scale.size() == 1) + if (tensorPtr->quantization->scale.size() <= 1) { - quantizationScale = tensorPtr->quantization->scale[0]; + CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1); + CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1); + + if (tensorPtr->quantization->scale.size() == 1) + { + quantizationScale = tensorPtr->quantization->scale[0]; + } + if (tensorPtr->quantization->zero_point.size() == 1) + { + // NOTE: we lose precision here when converting from 64 bit to 32 + // but this is what we support at the monent in ArmNN + quantizationOffset = boost::numeric_cast<int32_t>(tensorPtr->quantization->zero_point[0]); + } + + armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()), + safeShape.data(), + type, + quantizationScale, + quantizationOffset); + + return result; } - if (tensorPtr->quantization->zero_point.size() == 1) + else { - // NOTE: we lose precision here when converting from 64 bit to 32 - // but this is what we support at the monent in ArmNN - quantizationOffset = static_cast<int32_t>(tensorPtr->quantization->zero_point[0]); + std::vector<float> quantizationScales; + std::vector<int32_t> quantizationOffsets; + + // Scale + std::copy(tensorPtr->quantization->scale.begin(), + 
tensorPtr->quantization->scale.end(), + std::back_inserter(quantizationScales)); + + // QSymm Per-axis + armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()), + safeShape.data(), + type, + quantizationScales, + boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension)); + + return result; } } - - std::vector<unsigned int> safeShape = shapes; - if (safeShape.size() == 0) + else { - safeShape.push_back(1); + armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()), + safeShape.data(), + type, + quantizationScale, + quantizationOffset); + return result; } - - // two statements (on purpose) for easier debugging: - armnn::TensorInfo result(static_cast<unsigned int>(safeShape.size()), - safeShape.data(), - type, - quantizationScale, - quantizationOffset); - return result; } armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr) @@ -2848,6 +2878,11 @@ TfLiteParser::CreateConstTensor(TensorRawPtr tensorPtr, tensorPtr, tensorInfo, permutationVector); + case armnn::DataType::QSymmS8: + return CreateConstTensorAndStoreData<int8_t>(bufferPtr, + tensorPtr, + tensorInfo, + permutationVector); case armnn::DataType::Signed32: return CreateConstTensorAndStoreData<int32_t>(bufferPtr, tensorPtr, @@ -2977,6 +3012,7 @@ void ITfLiteParser::Destroy(ITfLiteParser* parser) TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<float[]> && data) : m_FloatData(std::move(data)) , m_Uint8Data(nullptr) +, m_Int8Data(nullptr) , m_Int32Data(nullptr) { } @@ -2984,6 +3020,15 @@ TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<float[] TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<uint8_t[]> && data) : m_FloatData(nullptr) , m_Uint8Data(std::move(data)) +, m_Int8Data(nullptr) +, m_Int32Data(nullptr) +{ +} + +TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int8_t[]> && data) +: m_FloatData(nullptr) +, m_Uint8Data(nullptr) +, 
m_Int8Data(std::move(data)) , m_Int32Data(nullptr) { } @@ -2991,6 +3036,7 @@ TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<uint8_t TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int32_t[]> && data) : m_FloatData(nullptr) , m_Uint8Data(nullptr) +, m_Int8Data(nullptr) , m_Int32Data(std::move(data)) { } |