diff options
-rw-r--r-- | Utils.cpp | 9 |
1 files changed, 5 insertions, 4 deletions
@@ -172,6 +172,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand) { using namespace armnn; bool perChannel = false; + bool isScalar = false; DataType type; switch (operand.type) @@ -202,6 +203,7 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand) break; case V1_3::OperandType::INT32: type = armnn::DataType::Signed32; + isScalar = true; break; case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED: type = armnn::DataType::QAsymmS8; @@ -211,14 +213,13 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand) } TensorInfo ret; - // 0 dimensional tensors will be flagged as scalars - if ( operand.dimensions.size() != 0) + if (isScalar) { - ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type); + ret = TensorInfo(TensorShape(armnn::Dimensionality::Scalar), type); } else { - ret = TensorInfo(TensorShape(armnn::Dimensionality::Scalar), type); + ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type); } if (perChannel) |