aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp24
1 files changed, 18 insertions, 6 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 741597b..f74df1d 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1650,13 +1650,25 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
return nullptr;
}
- auto ranked = val.getType().dyn_cast<mlir::RankedTensorType>();
- std::vector<int32_t> shape =
- ttype.hasRank() ? std::vector<int32_t>(ranked.getShape().begin(),
- ranked.getShape().end())
- : std::vector<int32_t>();
+ const bool is_unranked = !ttype.hasRank();
+ std::vector<int32_t> shape;
+ if (!is_unranked) {
+ auto shaped = val.getType().dyn_cast<mlir::ShapedType>();
+ assert(shaped);
+ for (int idx = 0; idx < ttype.getRank(); idx++) {
+ if (shaped.isDynamicDim(idx)) {
+ shape.push_back(0); // size of 0 represents dynamic dimension
+ } else {
+ auto dim = shaped.getDimSize(idx);
+ assert(dim > 0);
+ shape.push_back(dim);
+ }
+ }
+ }
+
DType type = Type2DType(ttype.getElementType());
- ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>());
+ ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
+ /* variable = */ false, is_unranked);
return ts;
}