aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-08-21 23:00:40 +0000
committerTai Ly <tai.ly@arm.com>2023-08-28 21:09:07 +0000
commit7566d1235cb646e46531c2eb34757cb4b3efa933 (patch)
tree670dc54644ab8e3c38ab9359604efc4ec171ad17 /src/TosaDeserialize.cpp
parentea49f62f7ab81750f19bef011683164fe9bd4080 (diff)
downloadtosa_mlir_translator-7566d1235cb646e46531c2eb34757cb4b3efa933.tar.gz
[tosa_mlir_translator] Support dynamic tensors
This adds serialization and deserialization support for: - unranked tensors (eg, *xi32) and - tensors with dynamic shapes (eg, ?x?xi32) Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ib2943333d8e3a199cf8a909b6db7197150666700
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp23
1 files changed, 20 insertions, 3 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 9e646c7..79f0c78 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -95,8 +95,15 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
<< "\n";
return mlir::failure();
}
- llvm::SmallVector<int64_t> shape(ts->GetShape().begin(),
- ts->GetShape().end());
+ llvm::SmallVector<int64_t> shape;
+ for (auto dim : ts->GetShape()) {
+ if (dim > 0) {
+ shape.push_back(dim);
+ } else {
+ // dynamic dim
+ shape.push_back(mlir::ShapedType::kDynamic);
+ }
+ }
type = mlir::RankedTensorType::get(llvm::makeArrayRef(shape), element_type);
return mlir::success();
}
@@ -1243,6 +1250,7 @@ private:
TosaMlirRegionBuilder *region_builder;
mlir::Block *block;
std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
+ std::unordered_set<std::string> unranked_tensors;
};
TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const {
@@ -1418,6 +1426,10 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
const auto &ts_name = ts->GetName();
tensor_type_map[ts_name] = type;
+ if (ts->GetIsUnranked()) {
+ assert(ts->GetShape().empty()); // unranked tensors should have shape = {}
+ unranked_tensors.insert(ts_name);
+ }
}
// Update operator_queue with operators whose inputs are all built
@@ -1443,7 +1455,12 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
// Initialize tensor_map/operator_queue based on block input arguments
for (const std::string &block_input_name : ser_block->GetInputs()) {
- auto type = tensor_type_map[block_input_name];
+ mlir::Type type = tensor_type_map[block_input_name];
+ if (unranked_tensors.count(block_input_name)) {
+ // recast type as unranked tensor type
+ auto element_type = type.cast<mlir::RankedTensorType>().getElementType();
+ type = mlir::UnrankedTensorType::get(element_type);
+ }
auto input_value = block->addArgument(type, loc);
if (tensor_map.count(block_input_name)) {
llvm::errs() << "ERROR: block input tensor " << block_input_name