aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
author Mike Kelly <mike.kelly@arm.com> 2023-01-03 16:29:44 +0000
committer mike.kelly <mike.kelly@arm.com> 2023-01-05 11:48:13 +0000
commit 0506ef0a099f5ba564af5e110e6857a68f462080 (patch)
tree 2ff1a15e435c41916a7f93f14766456759dd20b1 /src/armnnTfLiteParser
parent 8b4a483e0e2fee508c23be2248ba0409789f1a74 (diff)
download armnn-0506ef0a099f5ba564af5e110e6857a68f462080.tar.gz
GitHub #543 Problem Parsing Mixed-Precision Model
* Fixed bug when converting Constants with Per-Axis Quantization

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ifbea23e60483746ec987da491dae96e74cb33af4
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r-- src/armnnTfLiteParser/TfLiteParser.cpp 104
-rw-r--r-- src/armnnTfLiteParser/TfLiteParser.hpp 8
-rw-r--r-- src/armnnTfLiteParser/test/Conv2D.cpp 2
3 files changed, 59 insertions, 55 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 0484c6f478..191cfd2b48 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -316,6 +316,14 @@ std::vector<unsigned int> GetUIntBuffer(armnn::TensorInfo info,
::memcpy(uint64Buffer.data(), bufferPtr->data.data(), bufferPtr->data.size());
buffer.assign(std::begin(uint64Buffer), std::end(uint64Buffer));
}
+ else
+ {
+ CheckLocation location = CHECK_LOCATION();
+ throw ParseException(
+ fmt::format("Unsupported data type for uint buffer {}, only Signed 32 or Signed 64 are supported. {}",
+ GetDataTypeName(info.GetDataType()),
+ location.AsString()));
+ }
return buffer;
}
@@ -911,42 +919,16 @@ INetworkPtr TfLiteParserImpl::CreateNetworkFromModel()
return std::move(m_Network);
}
-std::unique_ptr<float[]> AsFloatArray(TfLiteParserImpl::BufferRawPtr bufferPtr,
- const TensorInfo& tensorInfo)
+bool TfLiteParserImpl::ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
+ armnn::DataType inputDataType,
+ armnn::DataType tensorDataType)
{
- if (tensorInfo.GetDataType() == DataType::QAsymmS8 || tensorInfo.GetDataType() == DataType::QSymmS8 ||
- tensorInfo.GetDataType() == DataType::QAsymmU8)
- {
- std::unique_ptr<float[]> buffer(new float[tensorInfo.GetNumElements()]);
-
- if (tensorInfo.HasPerAxisQuantization())
- {
- unsigned int axis = tensorInfo.GetQuantizationDim().value();
- auto axisDimensionality = tensorInfo.GetShape()[axis];
- auto axisFactor = armnnUtils::GetNumElementsAfter(tensorInfo.GetShape(), axis);
-
- for (unsigned int i = 0; i < tensorInfo.GetNumDimensions(); ++i)
- {
- unsigned int axisIndex = (i / axisFactor) % axisDimensionality;
- buffer[i] = Dequantize<int8_t>(bufferPtr->data[i], tensorInfo.GetQuantizationScales()[axisIndex],
- tensorInfo.GetQuantizationOffset());
- }
- }
- else
- {
- for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
- {
- buffer[i] = Dequantize<int8_t>(bufferPtr->data[i], tensorInfo.GetQuantizationScale(),
- tensorInfo.GetQuantizationOffset());
- }
- }
- return buffer;
- }
- throw ParseException(
- fmt::format("Unsupported input/weights combination: Input {} not supported with Weights {}",
- GetDataTypeName(DataType::Float32),
- GetDataTypeName(tensorInfo.GetDataType()),
- CHECK_LOCATION().AsString()));
+ return (TfLiteParserImpl::IsConstTensor(tensorPtr) && inputDataType == DataType::Float32 &&
+ (tensorDataType == DataType::QAsymmU8 ||
+ tensorDataType == DataType::QAsymmS8 ||
+ tensorDataType == DataType::QSymmS8 ||
+ tensorDataType == DataType::Signed32 ||
+ tensorDataType == DataType::Signed64));
}
void TfLiteParserImpl::RegisterProducerOfTensor(size_t subgraphIndex,
@@ -1136,9 +1118,7 @@ void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex)
auto layerName = fmt::format("Conv2D:{}:{}", subgraphIndex, operatorIndex);
armnn::IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, layerName.c_str());
- if (IsConstTensor(inputs[1]) && inputTensorInfo.GetDataType() == DataType::Float32 &&
- (filterTensorInfo.GetDataType() == DataType::QAsymmU8 ||
- filterTensorInfo.GetDataType() == DataType::QAsymmS8))
+ if (ShouldConstantTensorBeConverted(inputs[1], inputTensorInfo.GetDataType(), filterTensorInfo.GetDataType()))
{
m_ConstantsToDequantize.emplace_back(inputs[1]->buffer);
}
@@ -1150,9 +1130,7 @@ void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex)
// Add the biases input to the registration list, a constant layer will be added by SetupConstantLayers.
tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]);
- if (IsConstTensor(inputs[2]) && inputTensorInfo.GetDataType() == DataType::Float32 &&
- (filterTensorInfo.GetDataType() == DataType::QAsymmU8 ||
- filterTensorInfo.GetDataType() == DataType::QAsymmS8))
+ if (ShouldConstantTensorBeConverted(inputs[2], inputTensorInfo.GetDataType(), biasTensorInfo.GetDataType()))
{
m_ConstantsToDequantize.emplace_back(inputs[2]->buffer);
}
@@ -3112,9 +3090,7 @@ void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operator
// Add the weights input to the registration list, constant layers will be added by SetupConstantLayers if constant.
tensorIndexesToRegister.emplace_back(inputTensorIndexes[1]);
- if (desc.m_ConstantWeights && inputTensorInfo.GetDataType() == DataType::Float32 &&
- (filterTensorInfo.GetDataType() == DataType::QAsymmU8 ||
- filterTensorInfo.GetDataType() == DataType::QAsymmS8))
+ if (ShouldConstantTensorBeConverted(inputs[1], inputTensorInfo.GetDataType(), filterTensorInfo.GetDataType()))
{
m_ConstantsToDequantize.emplace_back(inputs[1]->buffer);
}
@@ -3127,9 +3103,7 @@ void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operator
// Add the biases input to the registration list, constant layer will be added by SetupConstantLayers.
tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]);
- if (desc.m_ConstantWeights && inputTensorInfo.GetDataType() == DataType::Float32 &&
- (biasTensorInfo.GetDataType() == DataType::QAsymmU8 ||
- biasTensorInfo.GetDataType() == DataType::QAsymmS8))
+ if (ShouldConstantTensorBeConverted(inputs[2], inputTensorInfo.GetDataType(), biasTensorInfo.GetDataType()))
{
m_ConstantsToDequantize.emplace_back(inputs[2]->buffer);
}
@@ -4925,11 +4899,22 @@ TfLiteParserImpl::CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
// Make sure isConstant flag is set.
tensorInfo.SetConstant();
- if (inputDataType == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32)
+ if (inputDataType == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32)
{
- TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
- std::unique_ptr<float[]> data = AsFloatArray(bufferPtr, tensorInfo);
- return std::make_pair(ConstTensor(constTensorInfo, data.get()), std::move(data));
+ try
+ {
+ TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
+ std::unique_ptr<float[]> data = armnnUtils::ToFloatArray(bufferPtr->data, tensorInfo);
+ return std::make_pair(ConstTensor(constTensorInfo, data.get()), std::move(data));
+ }
+ catch (armnn::InvalidArgumentException)
+ {
+ throw ParseException(
+ fmt::format("Unsupported input/weights combination: Input {} not supported with Weights {}",
+ GetDataTypeName(DataType::Float32),
+ GetDataTypeName(tensorInfo.GetDataType()),
+ CHECK_LOCATION().AsString()));
+ }
}
else
{
@@ -4950,9 +4935,20 @@ TfLiteParserImpl::CreateConstTensorPtr(TensorRawPtr tensorPtr, armnn::TensorInfo
if (inputTensorInfo.GetDataType() == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32)
{
- TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
- std::unique_ptr<float[]> data = AsFloatArray(bufferPtr, tensorInfo);
- return std::make_pair(new ConstTensor(constTensorInfo, data.get()), std::move(data));
+ try
+ {
+ TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
+ std::unique_ptr<float[]> data = armnnUtils::ToFloatArray(bufferPtr->data, tensorInfo);
+ return std::make_pair(new ConstTensor(constTensorInfo, data.get()), std::move(data));
+ }
+ catch (armnn::InvalidArgumentException)
+ {
+ throw ParseException(
+ fmt::format("Unsupported input/weights combination: Input {} not supported with Weights {}",
+ GetDataTypeName(DataType::Float32),
+ GetDataTypeName(tensorInfo.GetDataType()),
+ CHECK_LOCATION().AsString()));
+ }
}
else
{
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index f8ddc55649..7eb6c48501 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -242,7 +242,13 @@ private:
};
bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);
+
bool IsConstTensor(TensorRawPtr tensorPtr);
+
+ bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
+ armnn::DataType inputDataType,
+ armnn::DataType filterDataType);
+
armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo);
@@ -250,6 +256,7 @@ private:
CreateConstTensorPermuted(TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo,
armnn::Optional<armnn::PermutationVector&> permutationVector);
+
std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo,
@@ -261,6 +268,7 @@ private:
TfLiteParserImpl::TensorRawPtr tensorPtr,
armnn::TensorInfo& tensorInfo,
armnn::Optional<armnn::PermutationVector&> permutationVector);
+
std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
CreateConstTensorPtr(TensorRawPtr tensorPtr,
armnn::TensorInfo& inputTensorInfo);
diff --git a/src/armnnTfLiteParser/test/Conv2D.cpp b/src/armnnTfLiteParser/test/Conv2D.cpp
index 45c4a43519..334c102344 100644
--- a/src/armnnTfLiteParser/test/Conv2D.cpp
+++ b/src/armnnTfLiteParser/test/Conv2D.cpp
@@ -673,7 +673,7 @@ struct Conv2FloatWithInt8WeightsAndBiasesFixture : Conv2DWithBiasesFixture
"[ 1, 2, 2, 1 ]", // filterShape
"[ 2,1, 0,6 ]", // filterData
"[ 1 ]", // biasShape
- "[ 10, 0, 0, 0 ]", // biasData
+ "[ 10 ]", // biasData
"1", // stride w and h
"NONE", // activation
"1.0", // filterScale