From d00ad916b72a53eb2488d82899ec7f033294d959 Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Wed, 22 Jan 2020 15:55:16 +0000 Subject: IVGCVSW-4370 Remove use of deprecated per-axis type !armnn:2620 Signed-off-by: Derek Lamberti Change-Id: I8d71bac981a7576c7f51783833f76151495c62c0 --- Utils.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/Utils.cpp b/Utils.cpp index cdebfaed..c01604c3 100644 --- a/Utils.cpp +++ b/Utils.cpp @@ -9,6 +9,8 @@ #include +#include + #include #include @@ -44,7 +46,7 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::Float16: case armnn::DataType::Float32: case armnn::DataType::QAsymmU8: - case armnn::DataType::QuantizedSymm8PerAxis: + case armnn::DataType::QSymmS8: SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings); break; default: @@ -105,6 +107,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_0::Operand& operand) armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) { using namespace armnn; + bool perChannel = false; DataType type; switch (operand.type) @@ -115,12 +118,12 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) case V1_2::OperandType::TENSOR_FLOAT16: type = armnn::DataType::Float16; break; - case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: - type = armnn::DataType::QuantizedSymm8PerAxis; - break; case V1_2::OperandType::TENSOR_QUANT8_ASYMM: type = armnn::DataType::QAsymmU8; break; + case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + perChannel=true; + ARMNN_FALLTHROUGH; case V1_2::OperandType::TENSOR_QUANT8_SYMM: type = armnn::DataType::QSymmS8; break; @@ -135,7 +138,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) } TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type); - if (type == DataType::QuantizedSymm8PerAxis) + if (perChannel) { // ExtraParams is expected to be of type channelQuant BOOST_ASSERT(operand.extraParams.getDiscriminator() == -- cgit v1.2.1