diff options
Diffstat (limited to 'Utils.cpp')
-rw-r--r-- | Utils.cpp | 28 |
1 files changed, 24 insertions, 4 deletions
@@ -52,6 +52,9 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::QuantisedAsymm8: SwizzleAndroidNn4dTensorToArmNn<uint8_t>(tensor.GetShape(), input, output, mappings); break; + case armnn::DataType::QuantizedSymm8PerAxis: + SwizzleAndroidNn4dTensorToArmNn<int8_t>(tensor.GetShape(), input, output, mappings); + break; default: ALOGW("Unknown armnn::DataType for swizzling"); assert(0); @@ -109,8 +112,9 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_0::Operand& operand) armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) { - armnn::DataType type; + using namespace armnn; + DataType type; switch (operand.type) { case V1_2::OperandType::TENSOR_FLOAT32: @@ -119,6 +123,9 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) case V1_2::OperandType::TENSOR_FLOAT16: type = armnn::DataType::Float16; break; + case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: + type = armnn::DataType::QuantizedSymm8PerAxis; + break; case V1_2::OperandType::TENSOR_QUANT8_ASYMM: type = armnn::DataType::QuantisedAsymm8; break; @@ -132,10 +139,23 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) throw UnsupportedOperand<V1_2::OperandType>(operand.type); } - armnn::TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type); + TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type); + if (type == DataType::QuantizedSymm8PerAxis) + { + // ExtraParams is expected to be of type channelQuant + BOOST_ASSERT(operand.extraParams.getDiscriminator() == + V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant); - ret.SetQuantizationScale(operand.scale); - ret.SetQuantizationOffset(operand.zeroPoint); + auto perAxisQuantParams = operand.extraParams.channelQuant(); + + ret.SetQuantizationScales(perAxisQuantParams.scales); + ret.SetQuantizationDim(MakeOptional<unsigned int>(perAxisQuantParams.channelDim)); + } + else + { + ret.SetQuantizationScale(operand.scale); + ret.SetQuantizationOffset(operand.zeroPoint); + } return ret; } |