diff options
-rw-r--r-- | src/TosaSerialize.cpp | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 547da8c..e69fcba 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1559,17 +1559,18 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor( return nullptr; } - mlir::RankedTensorType tensor = - val.getType().dyn_cast<mlir::RankedTensorType>(); - if (!tensor) { - llvm::errs() << "TOSA serialization, attempt to build an " - "non-RankedTensorType Tensor\n"; + auto ttype = val.getType().dyn_cast<mlir::TensorType>(); + if (!ttype) { + llvm::errs() << "TOSA serialization, supplied value is not of TensorType\n"; return nullptr; } - std::vector<int32_t> shape(tensor.getShape().begin(), - tensor.getShape().end()); - DType type = Type2DType(tensor.getElementType()); + auto ranked = val.getType().dyn_cast<mlir::RankedTensorType>(); + std::vector<int32_t> shape = + ttype.hasRank() ? std::vector<int32_t>(ranked.getShape().begin(), + ranked.getShape().end()) + : std::vector<int32_t>(); + DType type = Type2DType(ttype.getElementType()); ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>()); return ts; |