diff options
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.cpp | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index 7889f03c5b..0e9d22463f 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -52,16 +52,20 @@ void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer* src IInputSlot& newInputSlot = quantizedLayer->GetInputSlot(i); IOutputSlot& newOutputSlot = prevQuantizedLayer->GetOutputSlot(slotIdx); newOutputSlot.Connect(newInputSlot); - - // Fetch the min/max ranges that were computed earlier - auto range = m_Ranges.GetRange(layerToFind.GetGuid(), slotIdx); - OffsetScalePair qParams = m_QuantizationScheme->ComputeScheme(range.first, range.second); - - // Set the quantization params TensorInfo info(outputSlot->GetTensorInfo()); - info.SetDataType(m_QuantizationScheme->GetDataType()); - info.SetQuantizationOffset(qParams.second); - info.SetQuantizationScale(qParams.first); + + // Only try to set quantization params on tensors that can be quantized + if (inputSlot->GetConnectedOutputSlot()->GetTensorInfo().GetDataType() != DataType::Boolean && + inputSlot->GetConnectedOutputSlot()->GetTensorInfo().GetDataType() != DataType::Signed32 && + inputSlot->GetConnectedOutputSlot()->GetTensorInfo().GetDataType() != DataType::Signed64) + { + // Fetch the min/max ranges that were computed earlier + auto range = m_Ranges.GetRange(layerToFind.GetGuid(), slotIdx); + OffsetScalePair qParams = m_QuantizationScheme->ComputeScheme(range.first, range.second); + info.SetDataType(m_QuantizationScheme->GetDataType()); + info.SetQuantizationOffset(qParams.second); + info.SetQuantizationScale(qParams.first); + } newOutputSlot.SetTensorInfo(info); } } |