aboutsummaryrefslogtreecommitdiff
path: root/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Utils.cpp')
-rw-r--r--Utils.cpp16
1 files changed, 14 insertions, 2 deletions
diff --git a/Utils.cpp b/Utils.cpp
index 6481c287..d94a9377 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -200,6 +200,9 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
case V1_3::OperandType::TENSOR_INT32:
type = armnn::DataType::Signed32;
break;
+ case V1_3::OperandType::INT32:
+ type = armnn::DataType::Signed32;
+ break;
case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
type = armnn::DataType::QAsymmS8;
break;
@@ -207,7 +210,17 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
throw UnsupportedOperand<V1_3::OperandType>(operand.type);
}
- TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+ TensorInfo ret;
+ // 0 dimensional tensors will be flagged as scalars
+ if ( operand.dimensions.size() != 0)
+ {
+ ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type);
+ }
+ else
+ {
+ ret = TensorInfo(TensorShape(armnn::Dimensionality::Scalar), type);
+ }
+
if (perChannel)
{
// ExtraParams is expected to be of type channelQuant
@@ -224,7 +237,6 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
ret.SetQuantizationScale(operand.scale);
ret.SetQuantizationOffset(operand.zeroPoint);
}
-
return ret;
}