From ab75f53fede64241cb71acbf3b9b0c7d325162d8 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Fri, 21 Oct 2022 10:49:48 -0700 Subject: Add RegionBuilder to TOSA MLIR Translator Rationale for making this change: - The original design only supports a single basicBlock which is no longer functionaly enough to support Control Flow operators like WhileOp or IFOp - Added another layer of abstraction of Region to support multiple basicBlocks + other corresponding fixes - There are other companion patches to make the above proposal work - Serialization Lib: Add TosaSerializationRegion to serialization_lib - Reference Model: Reference model update for control flow Signed-off-by: Jerry Ge Change-Id: Ic7eec3c32da87d409819365ba2dc7ef8b9619db4 --- src/TosaSerialize.cpp | 378 ++++++++++++++++++++++++------------------ third_party/serialization_lib | 2 +- 2 files changed, 220 insertions(+), 160 deletions(-) diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 625a922..67fd1c5 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 with LLVM Exceptions // (the "License"); you may not use this file except in compliance with @@ -116,6 +116,7 @@ static DType Type2PoolAccumDType(mlir::Type element_type) { return DType_UNKNOWN; } class TosaSerializationBlockBuilder; +class TosaSerializationRegionBuilder; class TosaSerializationOperatorBuilder { public: @@ -142,36 +143,63 @@ private: class TosaSerializationBlockBuilder { public: friend class TosaSerializationOperatorBuilder; - TosaSerializationBlockBuilder(TosaSerializationBasicBlock *_block, - TosaSerializationHandler *_tsh, - mlir::Region *_region) - : block(_block), tsh(_tsh), region(_region) {} + + TosaSerializationBlockBuilder(TosaSerializationBasicBlock* _ser_block, + TosaSerializationRegionBuilder* _region_builder, + mlir::Block* _block) + : ser_block(_ser_block), region_builder(_region_builder), block(_block) {} mlir::LogicalResult - BuildAllOpsInRegion(std::vector &return_values); - TosaSerializationBasicBlock *GetBlock() { return block; } - TosaSerializationHandler *GetTsh() { return tsh; } + BuildAllOpsInBlock(std::vector& return_values); + TosaSerializationBasicBlock* GetBlock() { return ser_block; } + TosaSerializationRegionBuilder* GetRegionBuilder() { return region_builder; } private: TosaSerializationOperator *BuildTosaSerializationOperator( - const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); - TosaSerializationTensor * + const TosaSerializationOperatorBuilder& op_builder, mlir::Operation& op); + TosaSerializationTensor* BuildTosaSerializationTensor(mlir::Value val, const std::string &name); - TosaSerializationBasicBlock *block; - TosaSerializationHandler *tsh; - mlir::Region *region; + TosaSerializationBasicBlock* ser_block; + TosaSerializationRegionBuilder* region_builder; + mlir::Block* block; std::unordered_map tensor_map; std::unordered_map input_tensor_map; }; +class TosaSerializationRegionBuilder { +public: + friend class TosaSerializationBlockBuilder; + friend class TosaSerializationOperatorBuilder; + + // Constructor + TosaSerializationRegionBuilder(TosaSerializationRegion* _ser_region, + TosaSerializationHandler* _tsh, + mlir::Region* _region) + : ser_region(_ser_region), tsh(_tsh), region(_region) {} + TosaSerializationHandler* GetTsh() { return tsh; } + mlir::LogicalResult BuildAllBlocksInRegion(std::vector& return_values); + + int getNumBlocksInRegion() const { return ser_region->GetBlocks().size(); } + +private: + mlir::Region* region; + TosaSerializationRegion* ser_region; + TosaSerializationHandler* tsh; + std::vector block_builders; +}; + std::string TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const { - if (block_builder->tensor_map.find(val) == block_builder->tensor_map.end()) { - llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map"; - assert(0); + // Traverse through each block builder in the region + for (auto curr_block_builder : block_builder->region_builder->block_builders) { + if (curr_block_builder->tensor_map.find(val) != curr_block_builder->tensor_map.end()) { + return curr_block_builder->tensor_map[val]; + } } - return block_builder->tensor_map[val]; + // Didn't find anything + llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n"; + assert(0); } // Main template to catch unimplemented translation. @@ -1245,26 +1273,27 @@ TosaSerializationOperator * TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::vector input_names, output_names; - - mlir::Region &then_region = op.getRegion(0); - mlir::Region &else_region = op.getRegion(1); + mlir::Block& then_block = op.getRegion(0).front(); + mlir::Block& else_block = op.getRegion(1).front(); std::vector then_yields, else_yields; - TosaSerializationBasicBlock *then_block = nullptr; - TosaSerializationBasicBlock *else_block = nullptr; + TosaSerializationBasicBlock* ser_then_block = nullptr; + TosaSerializationBasicBlock* ser_else_block = nullptr; // Building then branch block + std::string region_name = block_builder->region_builder->ser_region->GetName(); std::string then_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - then_block = new TosaSerializationBasicBlock( - then_block_name, std::vector(), + "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size()); + ser_then_block = new TosaSerializationBasicBlock( + then_block_name, region_name, std::vector(), std::vector(), std::vector(), std::vector()); - assert(then_block); - block_builder->GetTsh()->GetBlocks().push_back(then_block); + assert(ser_then_block); + block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_then_block); TosaSerializationBlockBuilder then_block_builder( - then_block, block_builder->GetTsh(), &then_region); - if (then_block_builder.BuildAllOpsInRegion(then_yields).failed()) { + ser_then_block, block_builder->region_builder, &then_block); + block_builder->region_builder->block_builders.push_back(&then_block_builder); + if (then_block_builder.BuildAllOpsInBlock(then_yields).failed()) { return nullptr; } if (then_yields.size() != op.getNumResults()) { @@ -1275,17 +1304,18 @@ TosaSerializationOperatorBuilder::build( // Building else branch block std::string else_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - else_block = new TosaSerializationBasicBlock( - else_block_name, std::vector(), + "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size()); + ser_else_block = new TosaSerializationBasicBlock( + else_block_name, region_name, std::vector(), std::vector(), std::vector(), std::vector()); - assert(else_block); - block_builder->GetTsh()->GetBlocks().push_back(else_block); + assert(ser_else_block); + block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_else_block); TosaSerializationBlockBuilder else_block_builder( - else_block, block_builder->GetTsh(), &else_region); - if (else_block_builder.BuildAllOpsInRegion(else_yields).failed()) { + ser_else_block, block_builder->region_builder, &else_block); + block_builder->region_builder->block_builders.push_back(&else_block_builder); + if (else_block_builder.BuildAllOpsInBlock(else_yields).failed()) { return nullptr; } if (else_yields.size() != op.getNumResults()) { @@ -1294,7 +1324,7 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - TosaCondIfAttribute attribute(then_block->GetName(), else_block->GetName()); + TosaCondIfAttribute attribute(ser_then_block->GetName(), ser_else_block->GetName()); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1319,25 +1349,30 @@ TosaSerializationOperatorBuilder::build( mlir::Operation &op) const { std::vector input_names, output_names; - mlir::Region &cond_region = op.getRegion(0); - mlir::Region &body_region = op.getRegion(1); + mlir::Block& cond_block = op.getRegion(0).front(); + mlir::Block& body_block = op.getRegion(1).front(); std::vector cond_yields, body_yields; - TosaSerializationBasicBlock *cond_block = nullptr; - TosaSerializationBasicBlock *body_block = nullptr; + TosaSerializationBasicBlock* ser_cond_block = nullptr; + TosaSerializationBasicBlock* ser_body_block = nullptr; // Building cond branch block std::string cond_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - cond_block = new TosaSerializationBasicBlock( - cond_block_name, std::vector(), + "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion()); + + std::string region_name = block_builder->region_builder->ser_region->GetName(); + + ser_cond_block = new TosaSerializationBasicBlock( + cond_block_name, region_name, std::vector(), std::vector(), std::vector(), std::vector()); - assert(cond_block); - block_builder->GetTsh()->GetBlocks().push_back(cond_block); + assert(ser_cond_block); + block_builder->region_builder->ser_region->GetBlocks().push_back(ser_cond_block); TosaSerializationBlockBuilder cond_block_builder( - cond_block, block_builder->GetTsh(), &cond_region); - if (cond_block_builder.BuildAllOpsInRegion(cond_yields).failed()) { + ser_cond_block, block_builder->region_builder, &cond_block); + block_builder->region_builder->block_builders.push_back(&cond_block_builder); + + if (cond_block_builder.BuildAllOpsInBlock(cond_yields).failed()) { return nullptr; } if (cond_yields.size() != 1) { @@ -1345,19 +1380,21 @@ TosaSerializationOperatorBuilder::build( return nullptr; } + // Building body branch block std::string body_block_name = - "bb" + std::to_string(block_builder->GetTsh()->GetBlocks().size()); - body_block = new TosaSerializationBasicBlock( - body_block_name, std::vector(), + "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion()); + ser_body_block = new TosaSerializationBasicBlock( + body_block_name, region_name, std::vector(), std::vector(), std::vector(), std::vector()); - assert(body_block); - block_builder->GetTsh()->GetBlocks().push_back(body_block); + assert(ser_body_block); + block_builder->region_builder->ser_region->GetBlocks().push_back(ser_body_block); TosaSerializationBlockBuilder body_block_builder( - body_block, block_builder->GetTsh(), &body_region); - if (body_block_builder.BuildAllOpsInRegion(body_yields).failed()) { + ser_body_block, block_builder->region_builder, &body_block); + block_builder->region_builder->block_builders.push_back(&body_block_builder); + if (body_block_builder.BuildAllOpsInBlock(body_yields).failed()) { return nullptr; } if (body_yields.size() != op.getNumResults()) { @@ -1366,8 +1403,9 @@ TosaSerializationOperatorBuilder::build( return nullptr; } - TosaWhileLoopAttribute attribute(cond_block->GetName(), - body_block->GetName()); + + TosaWhileLoopAttribute attribute(ser_cond_block->GetName(), + ser_body_block->GetName()); for (size_t i = 0; i < op.getNumOperands(); i++) { std::string input_name = GetTensorName(op.getOperand(i)); @@ -1387,8 +1425,40 @@ TosaSerializationOperatorBuilder::build( } /* End translating TOSA operator */ +mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::vector& return_values) { + std::string region_name = ser_region->GetName(); + // this will likely run once for most cases. + for (auto& block : this->region->getBlocks()) { + // TODO: update the block name + TosaSerializationBasicBlock* ser_block = new TosaSerializationBasicBlock( + std::string("main"), region_name, std::vector(), + std::vector(), std::vector(), + std::vector() + ); + + // build the block + TosaSerializationBlockBuilder block_builder(ser_block, this, &block); + // Region Builders need access to block builders + block_builders.push_back(&block_builder); + + if (block_builder.BuildAllOpsInBlock(return_values).failed()) { + return mlir::failure(); + } + + if (return_values.empty()) { + llvm::errs() << "BWarning: graph doesn't have return values\n"; + } + + // Add serialized block to serialized region + ser_region->GetBlocks().push_back(ser_block); + } + + return mlir::success(); +} -mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( + + +mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( std::vector &return_values) { TosaSerializationOperator *ser_operator = nullptr; TosaSerializationTensor *ser_tensor = nullptr; @@ -1398,104 +1468,93 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( static int output_tensor_index = 0; TosaSerializationOperatorBuilder op_builder(this); - for (auto &bb : region->getBlocks()) { - num_blocks_in_region++; - - if (num_blocks_in_region > 1) { - llvm::errs() << "Invalid MLIR: multiple blocks in a region\n"; - return mlir::failure(); - } - - // We always have one block for each region right now - assert(bb.isEntryBlock()); + // Specify block input tensor name + for (auto args : block->getArguments()) { + std::string block_input_name = + "TosaInput_" + std::to_string(input_tensor_index++); + ser_block->GetInputs().push_back(block_input_name); + tensor_map[args] = block_input_name; + input_tensor_map[args] = block_input_name; + } - // Specify block input tensor name - for (auto args : bb.getArguments()) { - std::string block_input_name = - "TosaInput_" + std::to_string(input_tensor_index++); - block->GetInputs().push_back(block_input_name); - tensor_map[args] = block_input_name; - input_tensor_map[args] = block_input_name; - } + // Build tensor_map + for (auto &op : block->getOperations()) { + if (!(llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op))) { + for (uint32_t i = 0; i < op.getNumResults(); i++) { + std::string intermediate_tensor_name = + "layer_" + std::to_string(intermediate_tensor_index++); + tensor_map[op.getResult(i)] = intermediate_tensor_name; + } + } else { + if (llvm::isa(op)) + continue; + // Override return tensor name + for (auto val : op.getOperands()) { + // Workaround to skip mlir::tensor::CastOp before return + mlir::Operation *val_defining_op = val.getDefiningOp(); + if (val_defining_op) { + if (llvm::isa(*val_defining_op)) + val = val_defining_op->getOperand(0); + } - // Build tensor_map - for (auto &op : bb) { - if (!(llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op))) { - for (uint32_t i = 0; i < op.getNumResults(); i++) { - std::string intermediate_tensor_name = - "layer_" + std::to_string(intermediate_tensor_index++); - tensor_map[op.getResult(i)] = intermediate_tensor_name; + // Sanity check. This mlir::Value should be built in map since graph + // is DAG + if (tensor_map.find(val) == tensor_map.end()) { + llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; + return mlir::failure(); } - } else { - if (llvm::isa(op)) - continue; - // Override return tensor name - for (auto val : op.getOperands()) { - // Workaround to skip mlir::tensor::CastOp before return - mlir::Operation *val_defining_op = val.getDefiningOp(); - if (val_defining_op) { - if (llvm::isa(*val_defining_op)) - val = val_defining_op->getOperand(0); - } - - // Sanity check. This mlir::Value should be built in map since graph - // is DAG - if (tensor_map.find(val) == tensor_map.end()) { - llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; - return mlir::failure(); - } - - // If returned value is block input, short-circuit the tensor name - // Otherwise, build a new output name and override the origin tensor - // name - if (input_tensor_map.find(val) != input_tensor_map.end()) { - block->GetOutputs().push_back(input_tensor_map[val]); - return_values.push_back(val); - } else { - std::string output_name = - "TosaOutput_" + std::to_string(output_tensor_index++); - tensor_map[val] = output_name; - block->GetOutputs().push_back(output_name); - return_values.push_back(val); - } + + // If returned value is block input, short-circuit the tensor name + // Otherwise, build a new output name and override the origin tensor + // name + if (input_tensor_map.find(val) != input_tensor_map.end()) { + ser_block->GetOutputs().push_back(input_tensor_map[val]); + return_values.push_back(val); + } else { + std::string output_name = + "TosaOutput_" + std::to_string(output_tensor_index++); + tensor_map[val] = output_name; + ser_block->GetOutputs().push_back(output_name); + return_values.push_back(val); } } } + } - // Build tensor - - // The tensor_map is sorted by hashed mlir::Value types. - // For serialization, sort tensors alphabetically by name for a - // deterministic and human-friendly ordering. - std::map tensor_name_sort; - for (auto pair : tensor_map) - tensor_name_sort[pair.second] = pair.first; - - for (auto pair : tensor_name_sort) { - ser_tensor = BuildTosaSerializationTensor(pair.second /* val */, - pair.first /* name */); - if (!ser_tensor) { - llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; - return mlir::failure(); - } - block->GetTensors().push_back(ser_tensor); - } + // Build tensor - // Build operator - for (auto &op : bb) { - if (llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op)) - continue; - ser_operator = BuildTosaSerializationOperator(op_builder, op); - if (!ser_operator) { - llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n"; - return mlir::failure(); - } - block->GetOperators().push_back(ser_operator); + // The tensor_map is sorted by hashed mlir::Value types. + // For serialization, sort tensors alphabetically by name for a + // deterministic and human-friendly ordering. + std::map tensor_name_sort; + for (auto pair : tensor_map) + tensor_name_sort[pair.second] = pair.first; + + + for (auto pair : tensor_name_sort) { + ser_tensor = BuildTosaSerializationTensor(pair.second /* val */, + pair.first /* name */); + if (!ser_tensor) { + llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; + return mlir::failure(); + } + ser_block->GetTensors().push_back(ser_tensor); + } + + // Build operator + for (auto &op : block->getOperations()) { + if (llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op)) + continue; + ser_operator = BuildTosaSerializationOperator(op_builder, op); + if (!ser_operator) { + llvm::errs() << "ERROR: Failed to build TosaSerializationOperator\n"; + return mlir::failure(); } + ser_block->GetOperators().push_back(ser_operator); } return mlir::success(); @@ -1529,6 +1588,7 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( // Sanity check the number of inputs/outputs of TOSA dialect matches the // number of TOSA flatbuffer if (op.getNumOperands() != target_operator->GetInputTensorNames().size()) { + llvm::errs() << op << "\n"; llvm::errs() << "WARNING. MLIR operator has " << op.getNumOperands() << " input tensors != Flatbuffer " "operator has " @@ -1551,7 +1611,7 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor( mlir::Value val, const std::string &name) { // If tensor already created before, use that tensor directly, create a new // one otherwise - TosaSerializationTensor *ts = block->GetTensorByName(name); + TosaSerializationTensor *ts = ser_block->GetTensorByName(name); if (ts) { return nullptr; } @@ -1574,7 +1634,7 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor( mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, TosaSerializationHandler &tsh) { - TosaSerializationBasicBlock *main_block; + TosaSerializationRegion* ser_main_region; mlir::Region *main_region = func.getCallableRegion(); std::vector main_returns; @@ -1584,21 +1644,21 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func, return mlir::failure(); } - if (!tsh.GetBlocks().empty()) { - llvm::errs() << "Internal Error: TosaSerializationHandler's block list " + if (!tsh.GetRegions().empty()) { + llvm::errs() << "Internal Error: TosaSerializationHandler's region list " "must be empty\n"; return mlir::failure(); } - main_block = new TosaSerializationBasicBlock( - std::string("main"), std::vector(), - std::vector(), std::vector(), - std::vector()); - assert(main_block); - tsh.GetBlocks().push_back(main_block); + ser_main_region = new TosaSerializationRegion( + std::string("main"), /* region_name */ + std::vector() /* empty serialized block container */ + ); + assert(ser_main_region); + tsh.GetRegions().push_back(ser_main_region); - TosaSerializationBlockBuilder block_builder(main_block, &tsh, main_region); - if (block_builder.BuildAllOpsInRegion(main_returns).failed()) { + TosaSerializationRegionBuilder region_builder(ser_main_region, &tsh, main_region); + if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) { return mlir::failure(); } diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 6388a09..ca7ce0e 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 6388a097de4350cc70472921c272074190fd7c93 +Subproject commit ca7ce0e94b3ee7339f31b47baa3a3fb4522243a2 -- cgit v1.2.1