aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Utils.cpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/Utils.cpp b/Utils.cpp
index d94a9377..8a2812ad 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -172,6 +172,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
{
using namespace armnn;
bool perChannel = false;
+ bool isScalar = false;
DataType type;
switch (operand.type)
@@ -202,6 +203,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
break;
case V1_3::OperandType::INT32:
type = armnn::DataType::Signed32;
+ isScalar = true;
break;
case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
type = armnn::DataType::QAsymmS8;
@@ -211,14 +213,13 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
}
TensorInfo ret;
- // 0 dimensional tensors will be flagged as scalars
- if ( operand.dimensions.size() != 0)
+ if (isScalar)
{
- ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type);
+ ret = TensorInfo(TensorShape(armnn::Dimensionality::Scalar), type);
}
else
{
- ret = TensorInfo(TensorShape(armnn::Dimensionality::Scalar), type);
+ ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type);
}
if (perChannel)