aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-08-21 23:00:40 +0000
committerTai Ly <tai.ly@arm.com>2023-08-28 21:09:07 +0000
commit7566d1235cb646e46531c2eb34757cb4b3efa933 (patch)
tree670dc54644ab8e3c38ab9359604efc4ec171ad17
parentea49f62f7ab81750f19bef011683164fe9bd4080 (diff)
downloadtosa_mlir_translator-7566d1235cb646e46531c2eb34757cb4b3efa933.tar.gz
[tosa_mlir_translator] Support dynamic tensors
This adds serialization and deserialization support for: - unranked tensors (eg, *xi32) and - tensors with dynamic shapes (eg, ?x?xi32) Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ib2943333d8e3a199cf8a909b6db7197150666700
-rw-r--r--src/TosaDeserialize.cpp23
-rw-r--r--src/TosaSerialize.cpp24
m---------third_party/serialization_lib0
3 files changed, 38 insertions, 9 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 9e646c7..79f0c78 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -95,8 +95,15 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
<< "\n";
return mlir::failure();
}
- llvm::SmallVector<int64_t> shape(ts->GetShape().begin(),
- ts->GetShape().end());
+ llvm::SmallVector<int64_t> shape;
+ for (auto dim : ts->GetShape()) {
+ if (dim > 0) {
+ shape.push_back(dim);
+ } else {
+ // dynamic dim
+ shape.push_back(mlir::ShapedType::kDynamic);
+ }
+ }
type = mlir::RankedTensorType::get(llvm::makeArrayRef(shape), element_type);
return mlir::success();
}
@@ -1243,6 +1250,7 @@ private:
TosaMlirRegionBuilder *region_builder;
mlir::Block *block;
std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
+ std::unordered_set<std::string> unranked_tensors;
};
TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const {
@@ -1418,6 +1426,10 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
const auto &ts_name = ts->GetName();
tensor_type_map[ts_name] = type;
+ if (ts->GetIsUnranked()) {
+ assert(ts->GetShape().empty()); // unranked tensors should have shape = {}
+ unranked_tensors.insert(ts_name);
+ }
}
// Update operator_queue with operators whose inputs are all built
@@ -1443,7 +1455,12 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
// Initialize tensor_map/operator_queue based on block input arguments
for (const std::string &block_input_name : ser_block->GetInputs()) {
- auto type = tensor_type_map[block_input_name];
+ mlir::Type type = tensor_type_map[block_input_name];
+ if (unranked_tensors.count(block_input_name)) {
+ // recast type as unranked tensor type
+ auto element_type = type.cast<mlir::RankedTensorType>().getElementType();
+ type = mlir::UnrankedTensorType::get(element_type);
+ }
auto input_value = block->addArgument(type, loc);
if (tensor_map.count(block_input_name)) {
llvm::errs() << "ERROR: block input tensor " << block_input_name
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 741597b..f74df1d 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1650,13 +1650,25 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
return nullptr;
}
- auto ranked = val.getType().dyn_cast<mlir::RankedTensorType>();
- std::vector<int32_t> shape =
- ttype.hasRank() ? std::vector<int32_t>(ranked.getShape().begin(),
- ranked.getShape().end())
- : std::vector<int32_t>();
+ const bool is_unranked = !ttype.hasRank();
+ std::vector<int32_t> shape;
+ if (!is_unranked) {
+ auto shaped = val.getType().dyn_cast<mlir::ShapedType>();
+ assert(shaped);
+ for (int idx = 0; idx < ttype.getRank(); idx++) {
+ if (shaped.isDynamicDim(idx)) {
+ shape.push_back(0); // size of 0 represents dynamic dimension
+ } else {
+ auto dim = shaped.getDimSize(idx);
+ assert(dim > 0);
+ shape.push_back(dim);
+ }
+ }
+ }
+
DType type = Type2DType(ttype.getElementType());
- ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>());
+ ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
+ /* variable = */ false, is_unranked);
return ts;
}
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 1adc5d05d9fd21591790678a3f1cdaa4c5b347c
+Subproject c6939a4d269968a34b0ae0aa579f0f0736aaecc