diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 290 |
1 files changed, 211 insertions, 79 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 3d38e1b..495d6f0 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -25,9 +25,9 @@ #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tosa_serialization_handler.h" #include <functional> -#include <map> #include <queue> #include <unordered_map> +#include <unordered_set> #include <vector> // The namespace might be confusing here. We have mlir::tosa:: defined in MLIR @@ -155,10 +155,12 @@ public: TosaMlirOperatorBuilder( mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block, mlir::Block *_block, mlir::Location _loc, + TosaMlirBlockBuilder *_block_builder, std::unordered_map<std::string, mlir::Value> *_tensor_map, std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map) : op_builder(_op_builder), ser_block(_ser_block), block(_block), - loc(_loc), tensor_map(_tensor_map), tensor_type_map(_tensor_type_map) {} + loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map), + tensor_type_map(_tensor_type_map) {} template <Op OPCODE> std::vector<mlir::Value> build(TosaSerializationOperator *op) const; @@ -179,6 +181,9 @@ public: return op_string; } + TosaSerializationHandler *GetTsh() const; + TosaMlirRegionBuilder *GetRegionBuilder() const; + private: template <class MLIR_OP> std::vector<mlir::Value> @@ -199,6 +204,7 @@ private: TosaSerializationBasicBlock *ser_block; mlir::Block *block; mlir::Location loc; + TosaMlirBlockBuilder *block_builder; std::unordered_map<std::string, mlir::Value> *tensor_map; std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map; }; @@ -208,7 +214,7 @@ template <Op OPCODE> std::vector<mlir::Value> TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const { llvm::errs() << "ERROR: " << get_string(op) << " translation hasn't been implemented\n"; - return std::vector<mlir::Value>(); + return {}; } // BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D) @@ -428,7 +434,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializat default: llvm::errs() << "ERROR: " << get_string(op) << " contains unsupported element type\n"; - return std::vector<mlir::Value>(); + return {}; } mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr); @@ -1020,25 +1026,6 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializa return std::vector<mlir::Value>({mlir_op->getResult(0)}); } -// TosaSerializationBasicBlock -// template <> -// std::vector<mlir::Value> -// TosaMlirOperatorBuilder::build<COND_IF>(TosaSerializationOperator* op) const -//{ -// // todo: mlir::tosa::IfOp -// return {}; -//} - -// TosaSerializationBasicBlock -// template <> -// std::vector<mlir::Value> -// TosaMlirOperatorBuilder::build<WHILE_LOOP>(TosaSerializationOperator* op) -// const -//{ -// // todo: mlir::tosa::WhileOp -// return {}; -//} - template <> std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const { @@ -1081,17 +1068,28 @@ TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const { class TosaMlirRegionBuilder { public: - TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region, - TosaSerializationHandler* _tsh, - mlir::Region* _region, - mlir::OpBuilder* _op_builder, - mlir::Location _loc) - : ser_region(_ser_region), tsh(_tsh), region(_region), op_builder(_op_builder), loc(_loc) {} + TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region, + TosaSerializationHandler *_tsh, mlir::Region *_region, + mlir::OpBuilder *_op_builder, mlir::Location _loc, + TosaMlirRegionBuilder *_parent_value_scope = nullptr) + : ser_region(_ser_region), tsh(_tsh), region(_region), + op_builder(_op_builder), loc(_loc) { + if (_parent_value_scope) { + // inherit parent_value_scope's tensor_map + for (auto &kv : _parent_value_scope->GetTensorMap()) { + tensor_map.insert(kv); + } + } + } mlir::LogicalResult BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values); mlir::OpBuilder* GetOpBuilder() { return op_builder; } mlir::Location GetLocation() { return loc; } + std::unordered_map<std::string, mlir::Value> &GetTensorMap() { + return tensor_map; + } + TosaSerializationHandler *GetTsh() const { return tsh; } private: mlir::Region* region; @@ -1099,7 +1097,7 @@ private: TosaSerializationHandler* tsh; mlir::OpBuilder* op_builder; mlir::Location loc; - std::vector<TosaMlirBlockBuilder*> block_builders; + std::unordered_map<std::string, mlir::Value> tensor_map; }; class TosaMlirBlockBuilder { @@ -1115,28 +1113,185 @@ public: mlir::OpBuilder* GetOpBuilder() { return region_builder->GetOpBuilder(); } mlir::Location GetLocation() { return region_builder->GetLocation(); } + std::unordered_map<std::string, mlir::Value> &GetTensorMap() { + return region_builder->GetTensorMap(); + } + + TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); } + TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; } private: TosaSerializationBasicBlock* ser_block; TosaMlirRegionBuilder* region_builder; mlir::Block* block; - std::unordered_map<std::string, mlir::Value> tensor_map; std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map; }; +TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const { + return block_builder->GetTsh(); +} + +TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const { + return block_builder->GetRegionBuilder(); +} + +// build control flow ops: + +namespace { + +mlir::LogicalResult +BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh, + mlir::Region *mlir_region, mlir::OpBuilder *op_builder, + mlir::Location loc, std::vector<mlir::Value> &return_values, + bool isolated_from_above = false, + TosaMlirRegionBuilder *parent_region_builder = nullptr) { + TosaMlirRegionBuilder *parent_value_scope = + isolated_from_above ? nullptr : parent_region_builder; + TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder, + loc, parent_value_scope); + return region_builder.BuildAllBlocksInRegion(return_values); +} + +} // namespace + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>( + TosaSerializationOperator *op) const { + mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0)); + std::vector<mlir::Value> input_values; + for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector<mlir::Type> output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + + assert(op->GetAttributeType() == + Attribute_CondIfAttribute); // double check attribute type + TosaCondIfAttribute *attr = + static_cast<TosaCondIfAttribute *>(op->GetAttribute()); + auto ser_then_region = GetTsh()->GetRegionByName(attr->then_branch()); + auto ser_else_region = GetTsh()->GetRegionByName(attr->else_branch()); + + if (!ser_then_region || !ser_else_region) { + llvm::errs() << "ERROR: " << get_string(op) + << " region serialization hasn't been implemented\n"; + return {}; + } + + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::IfOp>( + loc, output_types, cond_val, input_values); + + const bool isolated_from_above = + mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + mlir::Region &then_region = mlir_op->getRegion(0); + mlir::Region &else_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector<mlir::Value> then_returns, else_returns; + + if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc, + then_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (then_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " then_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc, + else_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (else_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " else_region yield.size() doesn't match cond_if's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector<mlir::Value>(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + +template <> +std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>( + TosaSerializationOperator *op) const { + std::vector<mlir::Value> input_values; + for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) { + input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx))); + } + std::vector<mlir::Type> output_types; + for (auto &name : op->GetInputTensorNames()) { + output_types.push_back(tensor_type_map->at(name)); + } + assert(op->GetAttributeType() == + Attribute_WhileLoopAttribute); // double check attribute type + TosaWhileLoopAttribute *attr = + static_cast<TosaWhileLoopAttribute *>(op->GetAttribute()); + auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_branch()); + auto ser_body_region = GetTsh()->GetRegionByName(attr->body_branch()); + + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values); + + const bool isolated_from_above = + mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>(); + + mlir::Region &cond_region = mlir_op->getRegion(0); + mlir::Region &body_region = mlir_op->getRegion(1); + + auto curr_region_builder = GetRegionBuilder(); + + std::vector<mlir::Value> cond_returns, body_returns; + + if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc, + cond_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (cond_returns.size() != 1) { + llvm::errs() << "ERROR: " << get_string(op) + << " cond_region yield.size() is not 1\n"; + return {}; + } + + if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc, + body_returns, isolated_from_above, curr_region_builder) + .failed()) { + return {}; + } + if (body_returns.size() != mlir_op->getNumResults()) { + llvm::errs() + << "ERROR: " << get_string(op) + << " body_region yield.size() doesn't match while_loop's output size\n"; + return {}; + } + + block->push_back(mlir_op); + return std::vector<mlir::Value>(mlir_op->getResults().begin(), + mlir_op->getResults().end()); +} + mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( std::vector<mlir::Value> &return_values) { block->clear(); auto loc = GetLocation(); auto op_builder = GetOpBuilder(); + auto &tensor_map = GetTensorMap(); - std::unordered_map<std::string, std::vector<TosaSerializationOperator*>> consumer_map; - std::unordered_map<std::string, bool> tensor_built; - std::unordered_map<TosaSerializationOperator*, bool> operator_built; + std::unordered_set<TosaSerializationOperator *> operator_built; std::queue<TosaSerializationOperator*> operator_queue; TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc, - &tensor_map, &tensor_type_map); + this, &tensor_map, &tensor_type_map); for (auto ts : ser_block->GetTensors()) { mlir::RankedTensorType type; @@ -1145,28 +1300,19 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( } const auto& ts_name = ts->GetName(); tensor_type_map[ts_name] = type; - tensor_built[ts_name] = false; } - for (auto op : ser_block->GetOperators()) { - operator_built[op] = false; - for (auto ts_name : op->GetInputTensorNames()) { - consumer_map[ts_name].push_back(op); - } - } - - // Update operator_queue if a consumer of tensor_name has all of its inputs already built - auto queue_ready_consumers = [&](const std::string tensor_name) { - for (auto consumer_op : consumer_map[tensor_name]) { - // Sanity check operator hasn't been built - if (operator_built[consumer_op]) { - llvm::errs() << "ERROR: " << tosa_op_builder.get_string(consumer_op) - << " is already built before its input is built\n"; - assert(0); + // Update operator_queue with operators whose inputs are all built + auto queue_ready_operators = [&]() { + // note: it is important to queue in order of original operators + // because an operator input may be defined in upper value scopes + for (auto consumer_op : ser_block->GetOperators()) { + if (operator_built.count(consumer_op)) { + continue; } bool all_inputs_ready = true; for (const auto& input_name : consumer_op->GetInputTensorNames()) { - if (!tensor_built[input_name]) { + if (!tensor_map.count(input_name)) { all_inputs_ready = false; break; } @@ -1177,7 +1323,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( } }; - // Initialize tensor_map/tensor_built/operator_queue based on block input arguments + // Initialize tensor_map/operator_queue based on block input arguments for (const std::string& block_input_name : ser_block->GetInputs()) { auto type = tensor_type_map[block_input_name]; auto input_value = block->addArgument(type, loc); @@ -1185,29 +1331,21 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( llvm::errs() << "ERROR: block input tensor " << block_input_name << " already exists\n"; return mlir::failure(); } - tensor_built[block_input_name] = true; tensor_map[block_input_name] = input_value; - queue_ready_consumers(block_input_name); } - // add all operators with 0 inputs (e.g., constant operators) to - // operator_queue - for (auto op : ser_block->GetOperators()) { - if (op->GetInputTensorNames().empty()) { - operator_queue.push(op); - } - } + queue_ready_operators(); while (!operator_queue.empty()) { TosaSerializationOperator* op = operator_queue.front(); operator_queue.pop(); // skip if operator has been built - if (operator_built[op]) { + if (operator_built.count(op)) { // this happens when same input appears twice or more in operator, eg, concat(%0, %0) continue; } - operator_built[op] = true; + operator_built.insert(op); std::vector<mlir::Value> output_values; if (false) { @@ -1232,22 +1370,23 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( for (size_t i = 0; i < output_values.size(); i++) { // Sanity check tensor hasn't been built std::string op_output_name = op->GetOutputTensorNames()[i]; - if (tensor_built[op_output_name]) { + if (tensor_map.count(op_output_name)) { llvm::errs() << "ERROR: tensor " << op_output_name << " is already built\n"; return mlir::failure(); } tensor_map[op_output_name] = output_values[i]; - tensor_built[op_output_name] = true; - queue_ready_consumers(op_output_name); } + // look for any more ready consumers + queue_ready_operators(); } // Construct return values std::vector<mlir::Value> return_operands; for (const auto& output_name : ser_block->GetOutputs()) { // Sanity check if terminator mlir::Value is built - if (!tensor_built[output_name]) { - llvm::errs() << "ERROR: terminator mlir::Value " << output_name << " is not built\n"; + if (!tensor_map.count(output_name)) { + llvm::errs() << "ERROR: terminator mlir::Value " << output_name + << " is not built in block " << ser_block->GetName() << "\n"; return mlir::failure(); } mlir::Value output_value = tensor_map.at(output_name); @@ -1267,8 +1406,6 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( for (auto& ser_block : ser_region->GetBlocks()) { auto& block = region->emplaceBlock(); TosaMlirBlockBuilder block_builder(ser_block, this, &block); - // Region Builders need access to block builders (?) - block_builders.push_back(&block_builder); if (block_builder.BuildAllOpsInBlock(return_values).failed()) { return mlir::failure(); @@ -1291,12 +1428,6 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, return mlir::failure(); } - if (tsh.GetRegions().size() != 1) { - llvm::errs() << "Internal Error: TosaSerializationHandler's region list " - "must contain exactly one region\n"; - return mlir::failure(); - } - TosaSerializationRegion* ser_main_region = tsh.GetRegions().front(); auto loc = func.getLoc(); @@ -1305,8 +1436,9 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, main_region->takeBody(*main_region); // empty old func body auto op_builder = mlir::OpBuilder(func.getBody()); - TosaMlirRegionBuilder region_builder(ser_main_region, &tsh, main_region, &op_builder, loc); - if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) { + if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc, + main_returns) + .failed()) { return mlir::failure(); } |