diff options
author | Mike Kelly <mike.kelly@arm.com> | 2023-01-03 16:29:44 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2023-01-05 11:48:13 +0000 |
commit | 0506ef0a099f5ba564af5e110e6857a68f462080 (patch) | |
tree | 2ff1a15e435c41916a7f93f14766456759dd20b1 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | 8b4a483e0e2fee508c23be2248ba0409789f1a74 (diff) | |
download | armnn-0506ef0a099f5ba564af5e110e6857a68f462080.tar.gz |
GitHub #543 Problem Parsing Mixed-Precision Model
* Fixed bug when converting Constants with Per-Axis Quantization
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ifbea23e60483746ec987da491dae96e74cb33af4
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 104 |
1 file changed, 50 insertions, 54 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 0484c6f478..191cfd2b48 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -316,6 +316,14 @@ std::vector<unsigned int> GetUIntBuffer(armnn::TensorInfo info, ::memcpy(uint64Buffer.data(), bufferPtr->data.data(), bufferPtr->data.size()); buffer.assign(std::begin(uint64Buffer), std::end(uint64Buffer)); } + else + { + CheckLocation location = CHECK_LOCATION(); + throw ParseException( + fmt::format("Unsupported data type for uint buffer {}, only Signed 32 or Signed 64 are supported. {}", + GetDataTypeName(info.GetDataType()), + location.AsString())); + } return buffer; } @@ -911,42 +919,16 @@ INetworkPtr TfLiteParserImpl::CreateNetworkFromModel() return std::move(m_Network); } -std::unique_ptr<float[]> AsFloatArray(TfLiteParserImpl::BufferRawPtr bufferPtr, - const TensorInfo& tensorInfo) +bool TfLiteParserImpl::ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr, + armnn::DataType inputDataType, + armnn::DataType tensorDataType) { - if (tensorInfo.GetDataType() == DataType::QAsymmS8 || tensorInfo.GetDataType() == DataType::QSymmS8 || - tensorInfo.GetDataType() == DataType::QAsymmU8) - { - std::unique_ptr<float[]> buffer(new float[tensorInfo.GetNumElements()]); - - if (tensorInfo.HasPerAxisQuantization()) - { - unsigned int axis = tensorInfo.GetQuantizationDim().value(); - auto axisDimensionality = tensorInfo.GetShape()[axis]; - auto axisFactor = armnnUtils::GetNumElementsAfter(tensorInfo.GetShape(), axis); - - for (unsigned int i = 0; i < tensorInfo.GetNumDimensions(); ++i) - { - unsigned int axisIndex = (i / axisFactor) % axisDimensionality; - buffer[i] = Dequantize<int8_t>(bufferPtr->data[i], tensorInfo.GetQuantizationScales()[axisIndex], - tensorInfo.GetQuantizationOffset()); - } - } - else - { - for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i) - { - buffer[i] = 
Dequantize<int8_t>(bufferPtr->data[i], tensorInfo.GetQuantizationScale(), - tensorInfo.GetQuantizationOffset()); - } - } - return buffer; - } - throw ParseException( - fmt::format("Unsupported input/weights combination: Input {} not supported with Weights {}", - GetDataTypeName(DataType::Float32), - GetDataTypeName(tensorInfo.GetDataType()), - CHECK_LOCATION().AsString())); + return (TfLiteParserImpl::IsConstTensor(tensorPtr) && inputDataType == DataType::Float32 && + (tensorDataType == DataType::QAsymmU8 || + tensorDataType == DataType::QAsymmS8 || + tensorDataType == DataType::QSymmS8 || + tensorDataType == DataType::Signed32 || + tensorDataType == DataType::Signed64)); } void TfLiteParserImpl::RegisterProducerOfTensor(size_t subgraphIndex, @@ -1136,9 +1118,7 @@ void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex) auto layerName = fmt::format("Conv2D:{}:{}", subgraphIndex, operatorIndex); armnn::IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, layerName.c_str()); - if (IsConstTensor(inputs[1]) && inputTensorInfo.GetDataType() == DataType::Float32 && - (filterTensorInfo.GetDataType() == DataType::QAsymmU8 || - filterTensorInfo.GetDataType() == DataType::QAsymmS8)) + if (ShouldConstantTensorBeConverted(inputs[1], inputTensorInfo.GetDataType(), filterTensorInfo.GetDataType())) { m_ConstantsToDequantize.emplace_back(inputs[1]->buffer); } @@ -1150,9 +1130,7 @@ void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex) // Add the biases input to the registration list, a constant layer will be added by SetupConstantLayers. 
tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]); - if (IsConstTensor(inputs[2]) && inputTensorInfo.GetDataType() == DataType::Float32 && - (filterTensorInfo.GetDataType() == DataType::QAsymmU8 || - filterTensorInfo.GetDataType() == DataType::QAsymmS8)) + if (ShouldConstantTensorBeConverted(inputs[2], inputTensorInfo.GetDataType(), biasTensorInfo.GetDataType())) { m_ConstantsToDequantize.emplace_back(inputs[2]->buffer); } @@ -3112,9 +3090,7 @@ void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operator // Add the weights input to the registration list, constant layers will be added by SetupConstantLayers if constant. tensorIndexesToRegister.emplace_back(inputTensorIndexes[1]); - if (desc.m_ConstantWeights && inputTensorInfo.GetDataType() == DataType::Float32 && - (filterTensorInfo.GetDataType() == DataType::QAsymmU8 || - filterTensorInfo.GetDataType() == DataType::QAsymmS8)) + if (ShouldConstantTensorBeConverted(inputs[1], inputTensorInfo.GetDataType(), filterTensorInfo.GetDataType())) { m_ConstantsToDequantize.emplace_back(inputs[1]->buffer); } @@ -3127,9 +3103,7 @@ void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operator // Add the biases input to the registration list, constant layer will be added by SetupConstantLayers. tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]); - if (desc.m_ConstantWeights && inputTensorInfo.GetDataType() == DataType::Float32 && - (biasTensorInfo.GetDataType() == DataType::QAsymmU8 || - biasTensorInfo.GetDataType() == DataType::QAsymmS8)) + if (ShouldConstantTensorBeConverted(inputs[2], inputTensorInfo.GetDataType(), biasTensorInfo.GetDataType())) { m_ConstantsToDequantize.emplace_back(inputs[2]->buffer); } @@ -4925,11 +4899,22 @@ TfLiteParserImpl::CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, // Make sure isConstant flag is set. 
tensorInfo.SetConstant(); - if (inputDataType == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32) + if (inputDataType == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32) { - TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true); - std::unique_ptr<float[]> data = AsFloatArray(bufferPtr, tensorInfo); - return std::make_pair(ConstTensor(constTensorInfo, data.get()), std::move(data)); + try + { + TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true); + std::unique_ptr<float[]> data = armnnUtils::ToFloatArray(bufferPtr->data, tensorInfo); + return std::make_pair(ConstTensor(constTensorInfo, data.get()), std::move(data)); + } + catch (armnn::InvalidArgumentException) + { + throw ParseException( + fmt::format("Unsupported input/weights combination: Input {} not supported with Weights {}", + GetDataTypeName(DataType::Float32), + GetDataTypeName(tensorInfo.GetDataType()), + CHECK_LOCATION().AsString())); + } } else { @@ -4950,9 +4935,20 @@ TfLiteParserImpl::CreateConstTensorPtr(TensorRawPtr tensorPtr, armnn::TensorInfo if (inputTensorInfo.GetDataType() == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32) { - TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true); - std::unique_ptr<float[]> data = AsFloatArray(bufferPtr, tensorInfo); - return std::make_pair(new ConstTensor(constTensorInfo, data.get()), std::move(data)); + try + { + TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true); + std::unique_ptr<float[]> data = armnnUtils::ToFloatArray(bufferPtr->data, tensorInfo); + return std::make_pair(new ConstTensor(constTensorInfo, data.get()), std::move(data)); + } + catch (armnn::InvalidArgumentException) + { + throw ParseException( + fmt::format("Unsupported input/weights combination: Input {} not supported with Weights {}", + GetDataTypeName(DataType::Float32), + 
GetDataTypeName(tensorInfo.GetDataType()), + CHECK_LOCATION().AsString())); + } } else { |