aboutsummaryrefslogtreecommitdiff
path: root/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Utils.cpp')
-rw-r--r--Utils.cpp28
1 files changed, 24 insertions, 4 deletions
diff --git a/Utils.cpp b/Utils.cpp
index 43b65ee3..246d6415 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -52,6 +52,9 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
case armnn::DataType::QuantisedAsymm8:
SwizzleAndroidNn4dTensorToArmNn<uint8_t>(tensor.GetShape(), input, output, mappings);
break;
+ case armnn::DataType::QuantizedSymm8PerAxis:
+ SwizzleAndroidNn4dTensorToArmNn<int8_t>(tensor.GetShape(), input, output, mappings);
+ break;
default:
ALOGW("Unknown armnn::DataType for swizzling");
assert(0);
@@ -109,8 +112,9 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_0::Operand& operand)
armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
{
- armnn::DataType type;
+ using namespace armnn;
+ DataType type;
switch (operand.type)
{
case V1_2::OperandType::TENSOR_FLOAT32:
@@ -119,6 +123,9 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
case V1_2::OperandType::TENSOR_FLOAT16:
type = armnn::DataType::Float16;
break;
+ case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ type = armnn::DataType::QuantizedSymm8PerAxis;
+ break;
case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
type = armnn::DataType::QuantisedAsymm8;
break;
@@ -132,10 +139,23 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
throw UnsupportedOperand<V1_2::OperandType>(operand.type);
}
- armnn::TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+ TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+ if (type == DataType::QuantizedSymm8PerAxis)
+ {
+ // ExtraParams is expected to be of type channelQuant
+ BOOST_ASSERT(operand.extraParams.getDiscriminator() ==
+ V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant);
- ret.SetQuantizationScale(operand.scale);
- ret.SetQuantizationOffset(operand.zeroPoint);
+ auto perAxisQuantParams = operand.extraParams.channelQuant();
+
+ ret.SetQuantizationScales(perAxisQuantParams.scales);
+ ret.SetQuantizationDim(MakeOptional<unsigned int>(perAxisQuantParams.channelDim));
+ }
+ else
+ {
+ ret.SetQuantizationScale(operand.scale);
+ ret.SetQuantizationOffset(operand.zeroPoint);
+ }
return ret;
}