aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/TosaSerialize.cpp17
1 files changed, 9 insertions, 8 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 547da8c..e69fcba 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1559,17 +1559,18 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
return nullptr;
}
- mlir::RankedTensorType tensor =
- val.getType().dyn_cast<mlir::RankedTensorType>();
- if (!tensor) {
- llvm::errs() << "TOSA serialization, attempt to build an "
- "non-RankedTensorType Tensor\n";
+ auto ttype = val.getType().dyn_cast<mlir::TensorType>();
+ if (!ttype) {
+ llvm::errs() << "TOSA serialization, supplied value is not of TensorType\n";
return nullptr;
}
- std::vector<int32_t> shape(tensor.getShape().begin(),
- tensor.getShape().end());
- DType type = Type2DType(tensor.getElementType());
+ auto ranked = val.getType().dyn_cast<mlir::RankedTensorType>();
+ std::vector<int32_t> shape =
+ ttype.hasRank() ? std::vector<int32_t>(ranked.getShape().begin(),
+ ranked.getShape().end())
+ : std::vector<int32_t>();
+ DType type = Type2DType(ttype.getElementType());
ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>());
return ts;