aboutsummaryrefslogtreecommitdiff
path: root/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Utils.cpp')
-rw-r--r--Utils.cpp13
1 files changed, 8 insertions, 5 deletions
diff --git a/Utils.cpp b/Utils.cpp
index cdebfaed..c01604c3 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -9,6 +9,8 @@
#include <armnnUtils/Permute.hpp>
+#include <armnn/Utils.hpp>
+
#include <cassert>
#include <cinttypes>
@@ -44,7 +46,7 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
case armnn::DataType::Float16:
case armnn::DataType::Float32:
case armnn::DataType::QAsymmU8:
- case armnn::DataType::QuantizedSymm8PerAxis:
+ case armnn::DataType::QSymmS8:
SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings);
break;
default:
@@ -105,6 +107,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_0::Operand& operand)
armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
{
using namespace armnn;
+ bool perChannel = false;
DataType type;
switch (operand.type)
@@ -115,12 +118,12 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
case V1_2::OperandType::TENSOR_FLOAT16:
type = armnn::DataType::Float16;
break;
- case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
- type = armnn::DataType::QuantizedSymm8PerAxis;
- break;
case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
type = armnn::DataType::QAsymmU8;
break;
+ case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
+ perChannel=true;
+ ARMNN_FALLTHROUGH;
case V1_2::OperandType::TENSOR_QUANT8_SYMM:
type = armnn::DataType::QSymmS8;
break;
@@ -135,7 +138,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
}
TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
- if (type == DataType::QuantizedSymm8PerAxis)
+ if (perChannel)
{
// ExtraParams is expected to be of type channelQuant
BOOST_ASSERT(operand.extraParams.getDiscriminator() ==