author    Keith Davis <keith.davis@arm.com>    2020-01-22 11:57:54 +0000
committer Keith Davis Arm <keith.davis@arm.com>    2020-01-28 17:15:32 +0000
commit    d305e1a203077bdbf2e3955abd252904127675a4 (patch)
tree      6c375c3ad2f0b62ca70b0914845d9fd4279f10d6
parent    b0efc60fa5740b34f1896a3c3e979f4dfd44fa2e (diff)
download  armnn-d305e1a203077bdbf2e3955abd252904127675a4.tar.gz
IVGCVSW-4335 Add support for per-channel QSymm8 to TfLite parser
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: I52f777f56138a27655a821aff376ecd0d3d23511
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp | 88
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.hpp |  2
2 files changed, 69 insertions, 21 deletions
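
The heart of the change is in ToTensorInfo(): when a tensor's quantization block carries more than one scale, the parser now builds a per-axis quantized TensorInfo from the scale vector and the quantized dimension, instead of a single scale/offset pair. A minimal sketch of that per-axis constructor call, mirroring the overload used in the hunk below (the shape, scales, and data type here are illustrative values, not taken from a real model):

#include <armnn/Tensor.hpp>
#include <vector>

armnn::TensorInfo MakePerAxisInfo()
{
    // Hypothetical per-channel weight tensor: 4 output channels, one scale per channel.
    std::vector<unsigned int> shape  = { 4, 3, 3, 3 };
    std::vector<float>        scales = { 0.02f, 0.015f, 0.03f, 0.01f };
    const unsigned int quantizedDimension = 0; // axis 0 holds the channels

    // Same TensorInfo overload the patch uses in the per-axis branch below.
    return armnn::TensorInfo(static_cast<unsigned int>(shape.size()),
                             shape.data(),
                             armnn::DataType::QSymmS8,
                             scales,
                             quantizedDimension);
}

Scalar-quantized and unquantized tensors keep going through the existing single scale/offset path.
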
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 17c0781740..d3eed9cfb1 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -336,40 +336,70 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
location.AsString()));
}
}
+ std::vector<unsigned int> safeShape = shapes;
+ if (safeShape.size() == 0)
+ {
+ safeShape.push_back(1);
+ }
float quantizationScale = 0.0f;
int32_t quantizationOffset = 0;
if (tensorPtr->quantization.get())
{
- CHECK_VALID_SIZE(tensorPtr->quantization->scale.size(), 0, 1);
- CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
-
- if (tensorPtr->quantization->scale.size() == 1)
+ if (tensorPtr->quantization->scale.size() <= 1)
{
- quantizationScale = tensorPtr->quantization->scale[0];
+ CHECK_VALID_SIZE(tensorPtr->quantization->scale.size(), 0, 1);
+ CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
+
+ if (tensorPtr->quantization->scale.size() == 1)
+ {
+ quantizationScale = tensorPtr->quantization->scale[0];
+ }
+ if (tensorPtr->quantization->zero_point.size() == 1)
+ {
+ // NOTE: we lose precision here when converting from 64 bit to 32
+ // but this is what we support at the moment in ArmNN
+ quantizationOffset = boost::numeric_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
+ }
+
+ armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
+ type,
+ quantizationScale,
+ quantizationOffset);
+
+ return result;
}
- if (tensorPtr->quantization->zero_point.size() == 1)
+ else
{
- // NOTE: we lose precision here when converting from 64 bit to 32
- // but this is what we support at the monent in ArmNN
- quantizationOffset = static_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
+ std::vector<float> quantizationScales;
+ std::vector<int32_t> quantizationOffsets;
+
+ // Scale
+ std::copy(tensorPtr->quantization->scale.begin(),
+ tensorPtr->quantization->scale.end(),
+ std::back_inserter(quantizationScales));
+
+ // QSymm Per-axis
+ armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
+ type,
+ quantizationScales,
+ boost::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
+
+ return result;
}
}
-
- std::vector<unsigned int> safeShape = shapes;
- if (safeShape.size() == 0)
+ else
{
- safeShape.push_back(1);
+ armnn::TensorInfo result(boost::numeric_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
+ type,
+ quantizationScale,
+ quantizationOffset);
+ return result;
}
-
- // two statements (on purpose) for easier debugging:
- armnn::TensorInfo result(static_cast<unsigned int>(safeShape.size()),
- safeShape.data(),
- type,
- quantizationScale,
- quantizationOffset);
- return result;
}
armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr)
@@ -2848,6 +2878,11 @@ TfLiteParser::CreateConstTensor(TensorRawPtr tensorPtr,
tensorPtr,
tensorInfo,
permutationVector);
+ case armnn::DataType::QSymmS8:
+ return CreateConstTensorAndStoreData<int8_t>(bufferPtr,
+ tensorPtr,
+ tensorInfo,
+ permutationVector);
case armnn::DataType::Signed32:
return CreateConstTensorAndStoreData<int32_t>(bufferPtr,
tensorPtr,
@@ -2977,6 +3012,7 @@ void ITfLiteParser::Destroy(ITfLiteParser* parser)
TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<float[]> && data)
: m_FloatData(std::move(data))
, m_Uint8Data(nullptr)
+, m_Int8Data(nullptr)
, m_Int32Data(nullptr)
{
}
@@ -2984,6 +3020,15 @@ TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<float[]
TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<uint8_t[]> && data)
: m_FloatData(nullptr)
, m_Uint8Data(std::move(data))
+, m_Int8Data(nullptr)
+, m_Int32Data(nullptr)
+{
+}
+
+TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int8_t[]> && data)
+: m_FloatData(nullptr)
+, m_Uint8Data(nullptr)
+, m_Int8Data(std::move(data))
, m_Int32Data(nullptr)
{
}
@@ -2991,6 +3036,7 @@ TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<uint8_t
TfLiteParser::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int32_t[]> && data)
: m_FloatData(nullptr)
, m_Uint8Data(nullptr)
+, m_Int8Data(nullptr)
, m_Int32Data(std::move(data))
{
}
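
The second .cpp hunk routes QSymmS8 constant data through CreateConstTensorAndStoreData<int8_t>, so per-channel weights stay as raw int8 values with one scale per slice of the quantized dimension. For intuition, a self-contained sketch of the symmetric per-channel dequantization this data represents (the values and channel layout are made up; backends perform this internally, it is not parser code):

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical per-channel QSymm8 weights: 2 channels x 3 values each.
    std::vector<int8_t> quantized = { 10, -5, 127,   64, -64, 3 };
    std::vector<float>  scales    = { 0.02f, 0.5f };  // one scale per channel
    const std::size_t valuesPerChannel = 3;

    // Symmetric quantization has a zero point of 0, so
    // real = scale[channel] * quantized_value.
    for (std::size_t i = 0; i < quantized.size(); ++i)
    {
        std::size_t channel = i / valuesPerChannel;
        float real = scales[channel] * static_cast<float>(quantized[i]);
        std::printf("channel %zu: %d -> %f\n", channel, quantized[i], real);
    }
    return 0;
}
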
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 42ea1a0372..a34e35fad1 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -166,12 +166,14 @@ private:
// Convenience constructors
SupportedDataStorage(std::unique_ptr<float[]>&& data);
SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
+ SupportedDataStorage(std::unique_ptr<int8_t[]>&& data);
SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
private:
// Pointers to the data buffers
std::unique_ptr<float[]> m_FloatData;
std::unique_ptr<uint8_t[]> m_Uint8Data;
+ std::unique_ptr<int8_t[]> m_Int8Data;
std::unique_ptr<int32_t[]> m_Int32Data;
};
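
With the header change, SupportedDataStorage can own int8 buffers alongside the existing float/uint8/int32 ones, so from a caller's point of view a TFLite model with per-channel QSymm8 convolution weights now simply parses. A hedged end-to-end sketch, assuming the usual ITfLiteParser entry points (ITfLiteParser::Create and CreateNetworkFromBinaryFile) and a hypothetical model file name:

#include <armnnTfLiteParser/ITfLiteParser.hpp>

int main()
{
    using namespace armnnTfLiteParser;

    // Create the parser and load a (hypothetical) model whose conv weights
    // are per-channel QSymm8; the parser now maps them to per-axis
    // quantized QSymmS8 tensors instead of rejecting them.
    ITfLiteParserPtr parser = ITfLiteParser::Create();
    armnn::INetworkPtr network =
        parser->CreateNetworkFromBinaryFile("per_channel_qsymm8_model.tflite");

    return network ? 0 : 1;
}
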