diff options
-rw-r--r-- | src/TosaDeserialize.cpp | 23 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 24 | ||||
m--------- | third_party/serialization_lib | 0 |
3 files changed, 38 insertions, 9 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 9e646c7..79f0c78 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -95,8 +95,15 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, << "\n"; return mlir::failure(); } - llvm::SmallVector<int64_t> shape(ts->GetShape().begin(), - ts->GetShape().end()); + llvm::SmallVector<int64_t> shape; + for (auto dim : ts->GetShape()) { + if (dim > 0) { + shape.push_back(dim); + } else { + // dynamic dim + shape.push_back(mlir::ShapedType::kDynamic); + } + } type = mlir::RankedTensorType::get(llvm::makeArrayRef(shape), element_type); return mlir::success(); } @@ -1243,6 +1250,7 @@ private: TosaMlirRegionBuilder *region_builder; mlir::Block *block; std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map; + std::unordered_set<std::string> unranked_tensors; }; TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const { @@ -1418,6 +1426,10 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( } const auto &ts_name = ts->GetName(); tensor_type_map[ts_name] = type; + if (ts->GetIsUnranked()) { + assert(ts->GetShape().empty()); // unranked tensors should have shape = {} + unranked_tensors.insert(ts_name); + } } // Update operator_queue with operators whose inputs are all built @@ -1443,7 +1455,12 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( // Initialize tensor_map/operator_queue based on block input arguments for (const std::string &block_input_name : ser_block->GetInputs()) { - auto type = tensor_type_map[block_input_name]; + mlir::Type type = tensor_type_map[block_input_name]; + if (unranked_tensors.count(block_input_name)) { + // recast type as unranked tensor type + auto element_type = type.cast<mlir::RankedTensorType>().getElementType(); + type = mlir::UnrankedTensorType::get(element_type); + } auto input_value = block->addArgument(type, loc); if (tensor_map.count(block_input_name)) { llvm::errs() << "ERROR: block input tensor " << block_input_name diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 741597b..f74df1d 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1650,13 +1650,25 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor( return nullptr; } - auto ranked = val.getType().dyn_cast<mlir::RankedTensorType>(); - std::vector<int32_t> shape = - ttype.hasRank() ? std::vector<int32_t>(ranked.getShape().begin(), - ranked.getShape().end()) - : std::vector<int32_t>(); + const bool is_unranked = !ttype.hasRank(); + std::vector<int32_t> shape; + if (!is_unranked) { + auto shaped = val.getType().dyn_cast<mlir::ShapedType>(); + assert(shaped); + for (int idx = 0; idx < ttype.getRank(); idx++) { + if (shaped.isDynamicDim(idx)) { + shape.push_back(0); // size of 0 represents dynamic dimension + } else { + auto dim = shaped.getDimSize(idx); + assert(dim > 0); + shape.push_back(dim); + } + } + } + DType type = Type2DType(ttype.getElementType()); - ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>()); + ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(), + /* variable = */ false, is_unranked); return ts; } diff --git a/third_party/serialization_lib b/third_party/serialization_lib -Subproject 1adc5d05d9fd21591790678a3f1cdaa4c5b347c +Subproject c6939a4d269968a34b0ae0aa579f0f0736aaecc |