diff options
-rw-r--r-- | Utils.cpp | 13 |
1 files changed, 8 insertions, 5 deletions
@@ -9,6 +9,8 @@ #include <armnnUtils/Permute.hpp> +#include <armnn/Utils.hpp> + #include <cassert> #include <cinttypes> @@ -44,7 +46,7 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::Float16: case armnn::DataType::Float32: case armnn::DataType::QAsymmU8: - case armnn::DataType::QuantizedSymm8PerAxis: + case armnn::DataType::QSymmS8: SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings); break; default: @@ -105,6 +107,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_0::Operand& operand) armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) { using namespace armnn; + bool perChannel = false; DataType type; switch (operand.type) @@ -115,12 +118,12 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) case V1_2::OperandType::TENSOR_FLOAT16: type = armnn::DataType::Float16; break; - case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: - type = armnn::DataType::QuantizedSymm8PerAxis; - break; case V1_2::OperandType::TENSOR_QUANT8_ASYMM: type = armnn::DataType::QAsymmU8; break; + case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + perChannel=true; + ARMNN_FALLTHROUGH; case V1_2::OperandType::TENSOR_QUANT8_SYMM: type = armnn::DataType::QSymmS8; break; @@ -135,7 +138,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) } TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type); - if (type == DataType::QuantizedSymm8PerAxis) + if (perChannel) { // ExtraParams is expected to be of type channelQuant BOOST_ASSERT(operand.extraParams.getDiscriminator() == |