From 8ffce6d63372ea93f3e060571137bce11d4735d8 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 16 Mar 2023 22:22:41 +0000 Subject: serialize/deserialize while/if ops using regions This changes serialization and deserialization of while and if ops to use regions instead of blocks. also changed deserialization to preserve original ordering such that: tosa1->deserialize->serialize->tosa2 => tosa1 == tosa2 most of the time. Signed-off-by: Tai Ly Change-Id: I539076788bdc466ba1881d955af349e6b4924ed8 --- src/TosaSerialize.cpp | 270 +++++++++++++++++++++++++++----------------------- 1 file changed, 145 insertions(+), 125 deletions(-) (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 8a95b68..4fd5014 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -123,8 +123,11 @@ public: TosaSerializationOperatorBuilder( TosaSerializationBlockBuilder *_block_builder) : block_builder(_block_builder) {} + template TosaSerializationOperator *build(mlir::Operation &op) const; + TosaSerializationHandler *GetTsh() const; + TosaSerializationRegionBuilder *GetRegionBuilder() const; private: std::string GetTensorName(mlir::Value val) const; @@ -142,8 +145,7 @@ private: // This builder assumes each region only has only one block class TosaSerializationBlockBuilder { public: - friend class TosaSerializationOperatorBuilder; - + // constructor TosaSerializationBlockBuilder(TosaSerializationBasicBlock* _ser_block, TosaSerializationRegionBuilder* _region_builder, mlir::Block* _block) @@ -151,8 +153,14 @@ public: mlir::LogicalResult BuildAllOpsInBlock(std::vector& return_values); - TosaSerializationBasicBlock* GetBlock() { return ser_block; } - TosaSerializationRegionBuilder* GetRegionBuilder() { return region_builder; } + TosaSerializationBasicBlock *GetBlock() const { return ser_block; } + TosaSerializationRegionBuilder *GetRegionBuilder() const { + return region_builder; + } + TosaSerializationHandler *GetTsh() const; + std::unordered_map &GetTensorMap() { + return tensor_map; + } private: TosaSerializationOperator *BuildTosaSerializationOperator( @@ -169,33 +177,54 @@ private: class TosaSerializationRegionBuilder { public: - friend class TosaSerializationBlockBuilder; - friend class TosaSerializationOperatorBuilder; - // Constructor - TosaSerializationRegionBuilder(TosaSerializationRegion* _ser_region, - TosaSerializationHandler* _tsh, - mlir::Region* _region) - : ser_region(_ser_region), tsh(_tsh), region(_region) {} - TosaSerializationHandler* GetTsh() { return tsh; } - mlir::LogicalResult BuildAllBlocksInRegion(std::vector& return_values); - - int getNumBlocksInRegion() const { return ser_region->GetBlocks().size(); } + TosaSerializationRegionBuilder( + TosaSerializationRegion *_ser_region, mlir::Region *_region, + TosaSerializationRegionBuilder *_parent_value_scope, + TosaSerializationHandler *_tsh) + : ser_region(_ser_region), region(_region), + parent_value_scope(_parent_value_scope), tsh(_tsh) {} + TosaSerializationHandler *GetTsh() const { return tsh; } + mlir::LogicalResult + BuildAllBlocksInRegion(bool is_top, std::vector &return_values); + TosaSerializationRegionBuilder *GetParentValueScope() const { + return parent_value_scope; + } + std::vector &GetBlockBuilders() { + return block_builders; + } private: - mlir::Region* region; TosaSerializationRegion* ser_region; + mlir::Region *region; + TosaSerializationRegionBuilder *parent_value_scope; TosaSerializationHandler* tsh; std::vector block_builders; }; +TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const { + return block_builder->GetTsh(); +} +TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const { + return region_builder->GetTsh(); +} +TosaSerializationRegionBuilder * +TosaSerializationOperatorBuilder::GetRegionBuilder() const { + return block_builder->GetRegionBuilder(); +} + std::string TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const { - // Traverse through each block builder in the region - for (auto curr_block_builder : block_builder->region_builder->block_builders) { - if (curr_block_builder->tensor_map.find(val) != curr_block_builder->tensor_map.end()) { - return curr_block_builder->tensor_map[val]; + auto value_scope = GetRegionBuilder(); + while (value_scope) { + // Traverse through each block builder in the region + for (auto curr_block_builder : value_scope->GetBlockBuilders()) { + const auto &tensor_map = curr_block_builder->GetTensorMap(); + if (tensor_map.count(val)) { + return tensor_map.at(val); + } } + value_scope = value_scope->GetParentValueScope(); } // Didn't find anything llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n"; @@ -1195,32 +1224,57 @@ TosaSerializationOperatorBuilder::build( return tyop; } +namespace { + +// serialize a region and all its blocks, and return region's return values +TosaSerializationRegion * +BuildRegion(mlir::Region ®ion, const std::string region_name, + const bool isolated_from_above, + TosaSerializationRegionBuilder *curr_region_builder, + TosaSerializationHandler *tsh, + std::vector &return_values, bool is_top = false) { + TosaSerializationRegion *ser_region = + new TosaSerializationRegion(region_name, {}); + assert(ser_region); + tsh->GetRegions().push_back(ser_region); + + TosaSerializationRegionBuilder *parent_value_scope = + isolated_from_above ? nullptr : curr_region_builder; + + TosaSerializationRegionBuilder region_builder(ser_region, ®ion, + parent_value_scope, tsh); + if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) { + return nullptr; + } + return ser_region; +} + +static int input_tensor_index = 0; +static int intermediate_tensor_index = 0; +static int output_tensor_index = 0; + +} // namespace + template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { + const std::string op_name = op.getName().getStringRef().str(); + const bool isolated_from_above = + op.hasTrait(); + auto curr_region_builder = GetRegionBuilder(); std::vector input_names, output_names; - mlir::Block& then_block = op.getRegion(0).front(); - mlir::Block& else_block = op.getRegion(1).front(); std::vector then_yields, else_yields; - TosaSerializationBasicBlock* ser_then_block = nullptr; - TosaSerializationBasicBlock* ser_else_block = nullptr; - - // Building then branch block - std::string region_name = block_builder->region_builder->ser_region->GetName(); - std::string then_block_name = - "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size()); - ser_then_block = new TosaSerializationBasicBlock( - then_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_then_block); - block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_then_block); - - TosaSerializationBlockBuilder then_block_builder( - ser_then_block, block_builder->region_builder, &then_block); - block_builder->region_builder->block_builders.push_back(&then_block_builder); - if (then_block_builder.BuildAllOpsInBlock(then_yields).failed()) { + auto tsh = GetTsh(); + + mlir::Region &then_region = op.getRegion(0); + mlir::Region &else_region = op.getRegion(1); + + const std::string then_region_name = op_name + "_then_region"; + TosaSerializationRegion *ser_then_region = + BuildRegion(then_region, then_region_name, isolated_from_above, + curr_region_builder, tsh, then_yields); + if (!ser_then_region) { return nullptr; } if (then_yields.size() != op.getNumResults()) { @@ -1229,20 +1283,11 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - // Building else branch block - std::string else_block_name = - "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size()); - ser_else_block = new TosaSerializationBasicBlock( - else_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_else_block); - block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_else_block); - - TosaSerializationBlockBuilder else_block_builder( - ser_else_block, block_builder->region_builder, &else_block); - block_builder->region_builder->block_builders.push_back(&else_block_builder); - if (else_block_builder.BuildAllOpsInBlock(else_yields).failed()) { + const std::string else_region_name = op_name + "_else_region"; + TosaSerializationRegion *ser_else_region = + BuildRegion(else_region, else_region_name, isolated_from_above, + curr_region_builder, tsh, else_yields); + if (!ser_else_region) { return nullptr; } if (else_yields.size() != op.getNumResults()) { @@ -1251,7 +1296,7 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - TosaCondIfAttribute attribute(ser_then_block->GetName(), ser_else_block->GetName()); + TosaCondIfAttribute attribute(then_region_name, else_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1263,9 +1308,9 @@ TosaSerializationOperatorBuilder::build( output_names.push_back(output_name); } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_COND_IF, Attribute_CondIfAttribute, &attribute, - input_names, output_names); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute, + &attribute, input_names, output_names); return tyop; } @@ -1274,32 +1319,22 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { + const std::string op_name = op.getName().getStringRef().str(); + const bool isolated_from_above = + op.hasTrait(); + auto curr_region_builder = GetRegionBuilder(); std::vector input_names, output_names; + auto tsh = GetTsh(); - mlir::Block& cond_block = op.getRegion(0).front(); - mlir::Block& body_block = op.getRegion(1).front(); + mlir::Region &cond_region = op.getRegion(0); + mlir::Region &body_region = op.getRegion(1); std::vector cond_yields, body_yields; - TosaSerializationBasicBlock* ser_cond_block = nullptr; - TosaSerializationBasicBlock* ser_body_block = nullptr; - - // Building cond branch block - std::string cond_block_name = - "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion()); - - std::string region_name = block_builder->region_builder->ser_region->GetName(); - ser_cond_block = new TosaSerializationBasicBlock( - cond_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_cond_block); - block_builder->region_builder->ser_region->GetBlocks().push_back(ser_cond_block); - - TosaSerializationBlockBuilder cond_block_builder( - ser_cond_block, block_builder->region_builder, &cond_block); - block_builder->region_builder->block_builders.push_back(&cond_block_builder); - - if (cond_block_builder.BuildAllOpsInBlock(cond_yields).failed()) { + const std::string cond_region_name = op_name + "_cond_region"; + TosaSerializationRegion *ser_cond_region = + BuildRegion(cond_region, cond_region_name, isolated_from_above, + curr_region_builder, tsh, cond_yields); + if (!ser_cond_region) { return nullptr; } if (cond_yields.size() != 1) { @@ -1307,21 +1342,11 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - - // Building body branch block - std::string body_block_name = - "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion()); - ser_body_block = new TosaSerializationBasicBlock( - body_block_name, region_name, std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(ser_body_block); - block_builder->region_builder->ser_region->GetBlocks().push_back(ser_body_block); - - TosaSerializationBlockBuilder body_block_builder( - ser_body_block, block_builder->region_builder, &body_block); - block_builder->region_builder->block_builders.push_back(&body_block_builder); - if (body_block_builder.BuildAllOpsInBlock(body_yields).failed()) { + const std::string body_region_name = op_name + "_body_region"; + TosaSerializationRegion *ser_body_region = + BuildRegion(body_region, body_region_name, isolated_from_above, + curr_region_builder, tsh, body_yields); + if (!ser_body_region) { return nullptr; } if (body_yields.size() != op.getNumResults()) { @@ -1330,9 +1355,7 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - - TosaWhileLoopAttribute attribute(ser_cond_block->GetName(), - ser_body_block->GetName()); + TosaWhileLoopAttribute attribute(cond_region_name, body_region_name); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1344,9 +1367,9 @@ TosaSerializationOperatorBuilder::build( output_names.push_back(output_name); } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, - input_names, output_names); + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute, + &attribute, input_names, output_names); return tyop; } @@ -1390,16 +1413,20 @@ TosaSerializationOperatorBuilder::build( } /* End translating TOSA operator */ -mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::vector& return_values) { +mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion( + bool is_top, std::vector &return_values) { std::string region_name = ser_region->GetName(); - // this will likely run once for most cases. + int block_index = 0; for (auto& block : this->region->getBlocks()) { - // TODO: update the block name - TosaSerializationBasicBlock* ser_block = new TosaSerializationBasicBlock( - std::string("main"), region_name, std::vector(), - std::vector(), std::vector(), - std::vector() - ); + // must name first block of top region "main" + const std::string block_name = + (is_top && block_index == 0) + ? "main" + : (region_name + "_bb" + std::to_string(block_index++)); + TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock( + block_name, region_name, std::vector(), + std::vector(), std::vector(), + std::vector()); // build the block TosaSerializationBlockBuilder block_builder(ser_block, this, &block); @@ -1421,16 +1448,11 @@ mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std:: return mlir::success(); } - - mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( std::vector &return_values) { TosaSerializationOperator *ser_operator = nullptr; TosaSerializationTensor *ser_tensor = nullptr; size_t num_blocks_in_region = 0; - static int input_tensor_index = 0; - static int intermediate_tensor_index = 0; - static int output_tensor_index = 0; TosaSerializationOperatorBuilder op_builder(this); // Specify block input tensor name @@ -1528,7 +1550,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( TosaSerializationOperator * TosaSerializationBlockBuilder::BuildTosaSerializationOperator( const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) { - std::string full_op_name = op.getName().getStringRef().str(); TosaSerializationOperator *target_operator = nullptr; if (false) { @@ -1600,8 +1621,6 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor( mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, TosaSerializationHandler &tsh) { - TosaSerializationRegion* ser_main_region; - mlir::Region *main_region = func.getCallableRegion(); std::vector main_returns; @@ -1616,15 +1635,16 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, return mlir::failure(); } - ser_main_region = new TosaSerializationRegion( - std::string("main"), /* region_name */ - std::vector() /* empty serialized block container */ - ); - assert(ser_main_region); - tsh.GetRegions().push_back(ser_main_region); + // reset static counters + input_tensor_index = 0; + intermediate_tensor_index = 0; + output_tensor_index = 0; - TosaSerializationRegionBuilder region_builder(ser_main_region, &tsh, main_region); - if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) { + TosaSerializationRegion *ser_main_region = + BuildRegion(*main_region, "main", /* isolated_from_above = */ true, + /* parent_value_scope = */ nullptr, &tsh, main_returns, + /* is_top = */ true); + if (!ser_main_region) { return mlir::failure(); } -- cgit v1.2.1