From 8ffce6d63372ea93f3e060571137bce11d4735d8 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 16 Mar 2023 22:22:41 +0000 Subject: serialize/deserialize while/if ops using regions This changes serialization and deserialization of while and if ops to use regions instead of blocks. also changed deserialization to preserve original ordering such that: tosa1->deserialize->serialize->tosa2 => tosa1 == tosa2 most of the time. Signed-off-by: Tai Ly Change-Id: I539076788bdc466ba1881d955af349e6b4924ed8 --- src/TosaDeserialize.cpp | 290 +++++++++++++++++++++++++++++++++++------------- src/TosaSerialize.cpp | 270 +++++++++++++++++++++++--------------------- 2 files changed, 356 insertions(+), 204 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 3d38e1b..495d6f0 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -25,9 +25,9 @@ #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tosa_serialization_handler.h" #include -#include #include #include +#include #include // The namespace might be confusing here. We have mlir::tosa:: defined in MLIR @@ -155,10 +155,12 @@ public: TosaMlirOperatorBuilder( mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block, mlir::Block *_block, mlir::Location _loc, + TosaMlirBlockBuilder *_block_builder, std::unordered_map *_tensor_map, std::unordered_map *_tensor_type_map) : op_builder(_op_builder), ser_block(_ser_block), block(_block), - loc(_loc), tensor_map(_tensor_map), tensor_type_map(_tensor_type_map) {} + loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map), + tensor_type_map(_tensor_type_map) {} template std::vector build(TosaSerializationOperator *op) const; @@ -179,6 +181,9 @@ public: return op_string; } + TosaSerializationHandler *GetTsh() const; + TosaMlirRegionBuilder *GetRegionBuilder() const; + private: template std::vector @@ -199,6 +204,7 @@ private: TosaSerializationBasicBlock *ser_block; mlir::Block *block; mlir::Location loc; + TosaMlirBlockBuilder *block_builder; std::unordered_map *tensor_map; std::unordered_map *tensor_type_map; }; @@ -208,7 +214,7 @@ template std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const { llvm::errs() << "ERROR: " << get_string(op) << " translation hasn't been implemented\n"; - return std::vector(); + return {}; } // BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) @@ -428,7 +434,7 @@ std::vector TosaMlirOperatorBuilder::build(TosaSerializat default: llvm::errs() << "ERROR: " << get_string(op) << " contains unsupported element type\n"; - return std::vector(); + return {}; } mlir::Operation *mlir_op = op_builder->create(loc, output_type, value_attr); @@ -1020,25 +1026,6 @@ std::vector TosaMlirOperatorBuilder::build(TosaSerializa return std::vector({mlir_op->getResult(0)}); } -// TosaSerializationBasicBlock -// template <> -// std::vector -// TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const -//{ -// // todo: mlir::tosa::IfOp -// return {}; -//} - -// TosaSerializationBasicBlock -// template <> -// std::vector -// TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) -// const -//{ -// // todo: mlir::tosa::WhileOp -// return {}; -//} - template <> std::vector TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { @@ -1081,17 +1068,28 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { class TosaMlirRegionBuilder { public: - TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region, - TosaSerializationHandler* _tsh, - mlir::Region* _region, - mlir::OpBuilder* _op_builder, - mlir::Location _loc) - : ser_region(_ser_region), tsh(_tsh), region(_region), op_builder(_op_builder), loc(_loc) {} + TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region, + TosaSerializationHandler *_tsh, mlir::Region *_region, + mlir::OpBuilder *_op_builder, mlir::Location _loc, + TosaMlirRegionBuilder *_parent_value_scope = nullptr) + : ser_region(_ser_region), tsh(_tsh), region(_region), + op_builder(_op_builder), loc(_loc) { + if (_parent_value_scope) { + // inherit parent_value_scope's tensor_map + for (auto &kv : _parent_value_scope->GetTensorMap()) { + tensor_map.insert(kv); + } + } + } mlir::LogicalResult BuildAllBlocksInRegion(std::vector& return_values); mlir::OpBuilder* GetOpBuilder() { return op_builder; } mlir::Location GetLocation() { return loc; } + std::unordered_map &GetTensorMap() { + return tensor_map; + } + TosaSerializationHandler *GetTsh() const { return tsh; } private: mlir::Region* region; @@ -1099,7 +1097,7 @@ private: TosaSerializationHandler* tsh; mlir::OpBuilder* op_builder; mlir::Location loc; - std::vector block_builders; + std::unordered_map tensor_map; }; class TosaMlirBlockBuilder { @@ -1115,28 +1113,185 @@ public: mlir::OpBuilder* GetOpBuilder() { return region_builder->GetOpBuilder(); } mlir::Location GetLocation() { return region_builder->GetLocation(); } + std::unordered_map &GetTensorMap() { + return region_builder->GetTensorMap(); + } + + TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); } + TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; } private: TosaSerializationBasicBlock* ser_block; TosaMlirRegionBuilder* region_builder; mlir::Block* block; - std::unordered_map tensor_map; std::unordered_map tensor_type_map; }; +TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const { + return block_builder->GetTsh(); +} + +TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const { + return block_builder->GetRegionBuilder(); +} + +// build control flow ops: + +namespace { + +mlir::LogicalResult +BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh, + mlir::Region *mlir_region, mlir::OpBuilder *op_builder, + mlir::Location loc, std::vector &return_values, + bool isolated_from_above = false, + TosaMlirRegionBuilder *parent_region_builder = nullptr) { + TosaMlirRegionBuilder *parent_value_scope = + isolated_from_above ? nullptr : parent_region_builder; + TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder, + loc, parent_value_scope); + return region_builder.BuildAllBlocksInRegion(return_values); +} + +} // namespace + +template <> +std::vector TosaMlirOperatorBuilder::build( + TosaSerializationOperator *op) const { + mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0)); + std::vector input_values; + for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + + assert(op->GetAttributeType() == + Attribute_CondIfAttribute); // double check attribute type + TosaCondIfAttribute *attr = + static_cast(op->GetAttribute()); + auto ser_then_region = GetTsh()->GetRegionByName(attr->then_branch()); + auto ser_else_region = GetTsh()->GetRegionByName(attr->else_branch()); + + if (!ser_then_region || !ser_else_region) { + llvm::errs() << "ERROR: " << get_string(op) + << " region serialization hasn't been implemented\n"; + return {}; + } + + mlir::Operation *mlir_op = op_builder->create( + loc, output_types, cond_val, input_values); + + const bool isolated_from_above = + mlir_op->hasTrait(); + mlir::Region &then_region = mlir_op->getRegion(0); + mlir::Region &else_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector then_returns, else_returns; + + if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc, + then_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (then_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " then_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc, + else_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (else_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " else_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + +template <> +std::vector TosaMlirOperatorBuilder::build( + TosaSerializationOperator *op) const { + std::vector input_values; + for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + assert(op->GetAttributeType() == + Attribute_WhileLoopAttribute); // double check attribute type + TosaWhileLoopAttribute *attr = + static_cast(op->GetAttribute()); + auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_branch()); + auto ser_body_region = GetTsh()->GetRegionByName(attr->body_branch()); + + mlir::Operation *mlir_op = + op_builder->create(loc, output_types, input_values); + + const bool isolated_from_above = + mlir_op->hasTrait(); + + mlir::Region &cond_region = mlir_op->getRegion(0); + mlir::Region &body_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector cond_returns, body_returns; + + if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc, + cond_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (cond_returns.size() != 1) { + llvm::errs() << "ERROR: " << get_string(op) + << " cond_region yield.size() is not 1\n"; + return {}; + } + + if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc, + body_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (body_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " body_region yield.size() doesn't match while_loop's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( std::vector &return_values) { block->clear(); auto loc = GetLocation(); auto op_builder = GetOpBuilder(); + auto &tensor_map = GetTensorMap(); - std::unordered_map> consumer_map; - std::unordered_map tensor_built; - std::unordered_map operator_built; + std::unordered_set operator_built; std::queue operator_queue; TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc, - &tensor_map, &tensor_type_map); + this, &tensor_map, &tensor_type_map); for (auto ts : ser_block->GetTensors()) { mlir::RankedTensorType type; @@ -1145,28 +1300,19 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( } const auto& ts_name = ts->GetName(); tensor_type_map[ts_name] = type; - tensor_built[ts_name] = false; } - for (auto op : ser_block->GetOperators()) { - operator_built[op] = false; - for (auto ts_name : op->GetInputTensorNames()) { - consumer_map[ts_name].push_back(op); - } - } - - // Update operator_queue if a consumer of tensor_name has all of its inputs already built - auto queue_ready_consumers = [&](const std::string tensor_name) { - for (auto consumer_op : consumer_map[tensor_name]) { - // Sanity check operator hasn't been built - if (operator_built[consumer_op]) { - llvm::errs() << "ERROR: " << tosa_op_builder.get_string(consumer_op) - << " is already built before its input is built\n"; - assert(0); + // Update operator_queue with operators whose inputs are all built + auto queue_ready_operators = [&]() { + // note: it is important to queue in order of original operators + // because an operator input may be defined in upper value scopes + for (auto consumer_op : ser_block->GetOperators()) { + if (operator_built.count(consumer_op)) { + continue; } bool all_inputs_ready = true; for (const auto& input_name : consumer_op->GetInputTensorNames()) { - if (!tensor_built[input_name]) { + if (!tensor_map.count(input_name)) { all_inputs_ready = false; break; } @@ -1177,7 +1323,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( } }; - // Initialize tensor_map/tensor_built/operator_queue based on block input arguments + // Initialize tensor_map/operator_queue based on block input arguments for (const std::string& block_input_name : ser_block->GetInputs()) { auto type = tensor_type_map[block_input_name]; auto input_value = block->addArgument(type, loc); @@ -1185,29 +1331,21 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( llvm::errs() << "ERROR: block input tensor " << block_input_name << " already exists\n"; return mlir::failure(); } - tensor_built[block_input_name] = true; tensor_map[block_input_name] = input_value; - queue_ready_consumers(block_input_name); } - // add all operators with 0 inputs (e.g., constant operators) to - // operator_queue - for (auto op : ser_block->GetOperators()) { - if (op->GetInputTensorNames().empty()) { - operator_queue.push(op); - } - } + queue_ready_operators(); while (!operator_queue.empty()) { TosaSerializationOperator* op = operator_queue.front(); operator_queue.pop(); // skip if operator has been built - if (operator_built[op]) { + if (operator_built.count(op)) { // this happens when same input appears twice or more in operator, eg, concat(%0, %0) continue; } - operator_built[op] = true; + operator_built.insert(op); std::vector output_values; if (false) { @@ -1232,22 +1370,23 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( for (size_t i = 0; i < output_values.size(); i++) { // Sanity check tensor hasn't been built std::string op_output_name = op->GetOutputTensorNames()[i]; - if (tensor_built[op_output_name]) { + if (tensor_map.count(op_output_name)) { llvm::errs() << "ERROR: tensor " << op_output_name << " is already built\n"; return mlir::failure(); } tensor_map[op_output_name] = output_values[i]; - tensor_built[op_output_name] = true; - queue_ready_consumers(op_output_name); } + // look for any more ready consumers + queue_ready_operators(); } // Construct return values std::vector return_operands; for (const auto& output_name : ser_block->GetOutputs()) { // Sanity check if terminator mlir::Value is built - if (!tensor_built[output_name]) { - llvm::errs() << "ERROR: terminator mlir::Value " << output_name << " is not built\n"; + if (!tensor_map.count(output_name)) { + llvm::errs() << "ERROR: terminator mlir::Value " << output_name + << " is not built in block " << ser_block->GetName() << "\n"; return mlir::failure(); } mlir::Value output_value = tensor_map.at(output_name); @@ -1267,8 +1406,6 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( for (auto& ser_block : ser_region->GetBlocks()) { auto& block = region->emplaceBlock(); TosaMlirBlockBuilder block_builder(ser_block, this, &block); - // Region Builders need access to block builders (?) - block_builders.push_back(&block_builder); if (block_builder.BuildAllOpsInBlock(return_values).failed()) { return mlir::failure(); @@ -1291,12 +1428,6 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, return mlir::failure(); } - if (tsh.GetRegions().size() != 1) { - llvm::errs() << "Internal Error: TosaSerializationHandler's region list " - "must contain exactly one region\n"; - return mlir::failure(); - } - TosaSerializationRegion* ser_main_region = tsh.GetRegions().front(); auto loc = func.getLoc(); @@ -1305,8 +1436,9 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, main_region->takeBody(*main_region); // empty old func body auto op_builder = mlir::OpBuilder(func.getBody()); - TosaMlirRegionBuilder region_builder(ser_main_region, &tsh, main_region, &op_builder, loc); - if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) { + if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc, + main_returns) + .failed()) { return mlir::failure(); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 8a95b68..4fd5014 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -123,8 +123,11 @@ public: TosaSerializationOperatorBuilder( TosaSerializationBlockBuilder *_block_builder) : block_builder(_block_builder) {} + template TosaSerializationOperator *build(mlir::Operation &op) const; + TosaSerializationHandler *GetTsh() const; + TosaSerializationRegionBuilder *GetRegionBuilder() const; private: std::string GetTensorName(mlir::Value val) const; @@ -142,8 +145,7 @@ private: // This builder assumes each region only has only one block class TosaSerializationBlockBuilder { public: - friend class TosaSerializationOperatorBuilder; - + // constructor TosaSerializationBlockBuilder(TosaSerializationBasicBlock* _ser_block, TosaSerializationRegionBuilder* _region_builder, mlir::Block* _block) @@ -151,8 +153,14 @@ public: mlir::LogicalResult BuildAllOpsInBlock(std::vector& return_values); - TosaSerializationBasicBlock* GetBlock() { return ser_block; } - TosaSerializationRegionBuilder* GetRegionBuilder() { return region_builder; } + TosaSerializationBasicBlock *GetBlock() const { return ser_block; } + TosaSerializationRegionBuilder *GetRegionBuilder() const { + return region_builder; + } + TosaSerializationHandler *GetTsh() const; + std::unordered_map &GetTensorMap() { + return tensor_map; + } private: TosaSerializationOperator *BuildTosaSerializationOperator( @@ -169,33 +177,54 @@ private: class TosaSerializationRegionBuilder { public: - friend class TosaSerializationBlockBuilder; - friend class TosaSerializationOperatorBuilder; - // Constructor - TosaSerializationRegionBuilder(TosaSerializationRegion* _ser_region, - TosaSerializationHandler* _tsh, - mlir::Region* _region) - : ser_region(_ser_region), tsh(_tsh), region(_region) {} - TosaSerializationHandler* GetTsh() { return tsh; } - mlir::LogicalResult BuildAllBlocksInRegion(std::vector& return_values); - - int getNumBlocksInRegion() const { return ser_region->GetBlocks().size(); } + TosaSerializationRegionBuilder( + TosaSerializationRegion *_ser_region, mlir::Region *_region, + TosaSerializationRegionBuilder *_parent_value_scope, + TosaSerializationHandler *_tsh) + : ser_region(_ser_region), region(_region), + parent_value_scope(_parent_value_scope), tsh(_tsh) {} + TosaSerializationHandler *GetTsh() const { return tsh; } + mlir::LogicalResult + BuildAllBlocksInRegion(bool is_top, std::vector &return_values); + TosaSerializationRegionBuilder *GetParentValueScope() const { + return parent_value_scope; + } + std::vector &GetBlockBuilders() { + return block_builders; + } private: - mlir::Region* region; TosaSerializationRegion* ser_region; + mlir::Region *region; + TosaSerializationRegionBuilder *parent_value_scope; TosaSerializationHandler* tsh; std::vector block_builders; }; +TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const { + return block_builder->GetTsh(); +} +TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const { + return region_builder->GetTsh(); +} +TosaSerializationRegionBuilder * +TosaSerializationOperatorBuilder::GetRegionBuilder() const { + return block_builder->GetRegionBuilder(); +} + std::string TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const { - // Traverse through each block builder in the region - for (auto curr_block_builder : block_builder->region_builder->block_builders) { - if (curr_block_builder->tensor_map.find(val) != curr_block_builder->tensor_map.end()) { - return curr_block_builder->tensor_map[val]; + auto value_scope = GetRegionBuilder(); + while (value_scope) { + // Traverse through each block builder in the region + for (auto curr_block_builder : value_scope->GetBlockBuilders()) { + const auto &tensor_map = curr_block_builder->GetTensorMap(); + if (tensor_map.count(val)) { + return tensor_map.at(val); + } } + value_scope = value_scope->GetParentValueScope(); } // Didn't find anything llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n"; @@ -1195,32 +1224,57 @@ TosaSerializationOperatorBuilder::build( return tyop; } +namespace { + +// serialize a region and all its blocks, and return region's return values +TosaSerializationRegion * +BuildRegion(mlir::Region ®ion, const std::string region_name, + const bool isolated_from_above, + TosaSerializationRegionBuilder *curr_region_builder, + TosaSerializationHandler *tsh, + std::vector &return_values, bool is_top = false) { + TosaSerializationRegion *ser_region = + new TosaSerializationRegion(region_name, {}); + assert(ser_region); + tsh->GetRegions().push_back(ser_region); + + TosaSerializationRegionBuilder *parent_value_scope = + isolated_from_above ? nullptr : curr_region_builder; + + TosaSerializationRegionBuilder region_builder(ser_region, ®ion, + parent_value_scope, tsh); + if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) { + return nullptr; + } + return ser_region; +} + +static int input_tensor_index = 0; +static int intermediate_tensor_index = 0; +static int output_tensor_index = 0; + +} // namespace + template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { + const std::string op_name = op.getName().getStringRef().str(); + const bool isolated_from_above = + op.hasTrait(); + auto curr_region_builder = GetRegionBuilder(); std::vector input_names, output_names; - mlir::Block& then_block = op.getRegion(0).front(); - mlir::Block& else_block = op.getRegion(1).front(); std::vector then_yields, else_yields; - TosaSerializationBasicBlock* ser_then_block = nullptr; - TosaSerializationBasicBlock* ser_else_block = nullptr; - - // Building then branch block - std::string region_name = block_builder->region_builder->ser_region->GetName(); - std::string then_block_name = - "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size()); - ser_then_block = new TosaSerializationBasicBlock( - then_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_then_block); - block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_then_block); - - TosaSerializationBlockBuilder then_block_builder( - ser_then_block, block_builder->region_builder, &then_block); - block_builder->region_builder->block_builders.push_back(&then_block_builder); - if (then_block_builder.BuildAllOpsInBlock(then_yields).failed()) { + auto tsh = GetTsh(); + + mlir::Region &then_region = op.getRegion(0); + mlir::Region &else_region = op.getRegion(1); + + const std::string then_region_name = op_name + "_then_region"; + TosaSerializationRegion *ser_then_region = + BuildRegion(then_region, then_region_name, isolated_from_above, + curr_region_builder, tsh, then_yields); + if (!ser_then_region) { return nullptr; } if (then_yields.size() != op.getNumResults()) { @@ -1229,20 +1283,11 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - // Building else branch block - std::string else_block_name = - "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size()); - ser_else_block = new TosaSerializationBasicBlock( - else_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_else_block); - block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_else_block); - - TosaSerializationBlockBuilder else_block_builder( - ser_else_block, block_builder->region_builder, &else_block); - block_builder->region_builder->block_builders.push_back(&else_block_builder); - if (else_block_builder.BuildAllOpsInBlock(else_yields).failed()) { + const std::string else_region_name = op_name + "_else_region"; + TosaSerializationRegion *ser_else_region = + BuildRegion(else_region, else_region_name, isolated_from_above, + curr_region_builder, tsh, else_yields); + if (!ser_else_region) { return nullptr; } if (else_yields.size() != op.getNumResults()) { @@ -1251,7 +1296,7 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - TosaCondIfAttribute attribute(ser_then_block->GetName(), ser_else_block->GetName()); + TosaCondIfAttribute attribute(then_region_name, else_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1263,9 +1308,9 @@ TosaSerializationOperatorBuilder::build( output_names.push_back(output_name); } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_COND_IF, Attribute_CondIfAttribute, &attribute, - input_names, output_names); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute, + &attribute, input_names, output_names); return tyop; } @@ -1274,32 +1319,22 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { + const std::string op_name = op.getName().getStringRef().str(); + const bool isolated_from_above = + op.hasTrait(); + auto curr_region_builder = GetRegionBuilder(); std::vector input_names, output_names; + auto tsh = GetTsh(); - mlir::Block& cond_block = op.getRegion(0).front(); - mlir::Block& body_block = op.getRegion(1).front(); + mlir::Region &cond_region = op.getRegion(0); + mlir::Region &body_region = op.getRegion(1); std::vector cond_yields, body_yields; - TosaSerializationBasicBlock* ser_cond_block = nullptr; - TosaSerializationBasicBlock* ser_body_block = nullptr; - - // Building cond branch block - std::string cond_block_name = - "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion()); - - std::string region_name = block_builder->region_builder->ser_region->GetName(); - ser_cond_block = new TosaSerializationBasicBlock( - cond_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_cond_block); - block_builder->region_builder->ser_region->GetBlocks().push_back(ser_cond_block); - - TosaSerializationBlockBuilder cond_block_builder( - ser_cond_block, block_builder->region_builder, &cond_block); - block_builder->region_builder->block_builders.push_back(&cond_block_builder); - - if (cond_block_builder.BuildAllOpsInBlock(cond_yields).failed()) { + const std::string cond_region_name = op_name + "_cond_region"; + TosaSerializationRegion *ser_cond_region = + BuildRegion(cond_region, cond_region_name, isolated_from_above, + curr_region_builder, tsh, cond_yields); + if (!ser_cond_region) { return nullptr; } if (cond_yields.size() != 1) { @@ -1307,21 +1342,11 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - - // Building body branch block - std::string body_block_name = - "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion()); - ser_body_block = new TosaSerializationBasicBlock( - body_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_body_block); - block_builder->region_builder->ser_region->GetBlocks().push_back(ser_body_block); - - TosaSerializationBlockBuilder body_block_builder( - ser_body_block, block_builder->region_builder, &body_block); - block_builder->region_builder->block_builders.push_back(&body_block_builder); - if (body_block_builder.BuildAllOpsInBlock(body_yields).failed()) { + const std::string body_region_name = op_name + "_body_region"; + TosaSerializationRegion *ser_body_region = + BuildRegion(body_region, body_region_name, isolated_from_above, + curr_region_builder, tsh, body_yields); + if (!ser_body_region) { return nullptr; } if (body_yields.size() != op.getNumResults()) { @@ -1330,9 +1355,7 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - - TosaWhileLoopAttribute attribute(ser_cond_block->GetName(), - ser_body_block->GetName()); + TosaWhileLoopAttribute attribute(cond_region_name, body_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1344,9 +1367,9 @@ TosaSerializationOperatorBuilder::build( output_names.push_back(output_name); } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, - input_names, output_names); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute, + &attribute, input_names, output_names); return tyop; } @@ -1390,16 +1413,20 @@ TosaSerializationOperatorBuilder::build( } /* End translating TOSA operator */ -mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::vector& return_values) { +mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion( + bool is_top, std::vector &return_values) { std::string region_name = ser_region->GetName(); - // this will likely run once for most cases. + int block_index = 0; for (auto& block : this->region->getBlocks()) { - // TODO: update the block name - TosaSerializationBasicBlock* ser_block = new TosaSerializationBasicBlock( - std::string("main"), region_name, std::vector(), - std::vector(), std::vector(), - std::vector() - ); + // must name first block of top region "main" + const std::string block_name = + (is_top && block_index == 0) + ? "main" + : (region_name + "_bb" + std::to_string(block_index++)); + TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock( + block_name, region_name, std::vector(), + std::vector(), std::vector(), + std::vector()); // build the block TosaSerializationBlockBuilder block_builder(ser_block, this, &block); @@ -1421,16 +1448,11 @@ mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std:: return mlir::success(); } - - mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( std::vector &return_values) { TosaSerializationOperator *ser_operator = nullptr; TosaSerializationTensor *ser_tensor = nullptr; size_t num_blocks_in_region = 0; - static int input_tensor_index = 0; - static int intermediate_tensor_index = 0; - static int output_tensor_index = 0; TosaSerializationOperatorBuilder op_builder(this); // Specify block input tensor name @@ -1528,7 +1550,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( TosaSerializationOperator * TosaSerializationBlockBuilder::BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) { - std::string full_op_name = op.getName().getStringRef().str(); TosaSerializationOperator *target_operator = nullptr; if (false) { @@ -1600,8 +1621,6 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor( mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, TosaSerializationHandler &tsh) { - TosaSerializationRegion* ser_main_region; - mlir::Region *main_region = func.getCallableRegion(); std::vector main_returns; @@ -1616,15 +1635,16 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, return mlir::failure(); } - ser_main_region = new TosaSerializationRegion( - std::string("main"), /* region_name */ - std::vector() /* empty serialized block container */ - ); - assert(ser_main_region); - tsh.GetRegions().push_back(ser_main_region); + // reset static counters + input_tensor_index = 0; + intermediate_tensor_index = 0; + output_tensor_index = 0; - TosaSerializationRegionBuilder region_builder(ser_main_region, &tsh, main_region); - if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) { + TosaSerializationRegion *ser_main_region = + BuildRegion(*main_region, "main", /* isolated_from_above = */ true, + /* parent_value_scope = */ nullptr, &tsh, main_returns, + /* is_top = */ true); + if (!ser_main_region) { return mlir::failure(); } -- cgit v1.2.1