aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-16 22:22:41 +0000
committerTai Ly <tai.ly@arm.com>2023-04-26 00:23:38 +0000
commit8ffce6d63372ea93f3e060571137bce11d4735d8 (patch)
tree00b8a75f1cc5f6331818fefa750c08131f801f5c
parent828099da4d3777254688f3a75063a705762064d6 (diff)
downloadtosa_mlir_translator-8ffce6d63372ea93f3e060571137bce11d4735d8.tar.gz
serialize/deserialize while/if ops using regions
This changes serialization and deserialization of while and if ops to use regions instead of blocks. also changed deserialization to preserve original ordering such that: tosa1->deserialize->serialize->tosa2 => tosa1 == tosa2 most of the time. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I539076788bdc466ba1881d955af349e6b4924ed8
-rw-r--r--src/TosaDeserialize.cpp290
-rw-r--r--src/TosaSerialize.cpp270
2 files changed, 356 insertions, 204 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 3d38e1b..495d6f0 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -25,9 +25,9 @@
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tosa_serialization_handler.h"
#include <functional>
-#include <map>
#include <queue>
#include <unordered_map>
+#include <unordered_set>
#include <vector>
// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR
@@ -155,10 +155,12 @@ public:
TosaMlirOperatorBuilder(
mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block,
mlir::Block *_block, mlir::Location _loc,
+ TosaMlirBlockBuilder *_block_builder,
std::unordered_map<std::string, mlir::Value> *_tensor_map,
std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map)
: op_builder(_op_builder), ser_block(_ser_block), block(_block),
- loc(_loc), tensor_map(_tensor_map), tensor_type_map(_tensor_type_map) {}
+ loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map),
+ tensor_type_map(_tensor_type_map) {}
template <Op OPCODE>
std::vector<mlir::Value> build(TosaSerializationOperator *op) const;
@@ -179,6 +181,9 @@ public:
return op_string;
}
+ TosaSerializationHandler *GetTsh() const;
+ TosaMlirRegionBuilder *GetRegionBuilder() const;
+
private:
template <class MLIR_OP>
std::vector<mlir::Value>
@@ -199,6 +204,7 @@ private:
TosaSerializationBasicBlock *ser_block;
mlir::Block *block;
mlir::Location loc;
+ TosaMlirBlockBuilder *block_builder;
std::unordered_map<std::string, mlir::Value> *tensor_map;
std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map;
};
@@ -208,7 +214,7 @@ template <Op OPCODE>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const
{
llvm::errs() << "ERROR: " << get_string(op) << " translation hasn't been implemented\n";
- return std::vector<mlir::Value>();
+ return {};
}
// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
@@ -428,7 +434,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializat
default:
llvm::errs() << "ERROR: " << get_string(op)
<< " contains unsupported element type\n";
- return std::vector<mlir::Value>();
+ return {};
}
mlir::Operation *mlir_op =
op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr);
@@ -1020,25 +1026,6 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializa
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
-// TosaSerializationBasicBlock
-// template <>
-// std::vector<mlir::Value>
-// TosaMlirOperatorBuilder::build<COND_IF>(TosaSerializationOperator* op) const
-//{
-// // todo: mlir::tosa::IfOp
-// return {};
-//}
-
-// TosaSerializationBasicBlock
-// template <>
-// std::vector<mlir::Value>
-// TosaMlirOperatorBuilder::build<WHILE_LOOP>(TosaSerializationOperator* op)
-// const
-//{
-// // todo: mlir::tosa::WhileOp
-// return {};
-//}
-
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const {
@@ -1081,17 +1068,28 @@ TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const {
class TosaMlirRegionBuilder {
public:
- TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region,
- TosaSerializationHandler* _tsh,
- mlir::Region* _region,
- mlir::OpBuilder* _op_builder,
- mlir::Location _loc)
- : ser_region(_ser_region), tsh(_tsh), region(_region), op_builder(_op_builder), loc(_loc) {}
+ TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region,
+ TosaSerializationHandler *_tsh, mlir::Region *_region,
+ mlir::OpBuilder *_op_builder, mlir::Location _loc,
+ TosaMlirRegionBuilder *_parent_value_scope = nullptr)
+ : ser_region(_ser_region), tsh(_tsh), region(_region),
+ op_builder(_op_builder), loc(_loc) {
+ if (_parent_value_scope) {
+ // inherit parent_value_scope's tensor_map
+ for (auto &kv : _parent_value_scope->GetTensorMap()) {
+ tensor_map.insert(kv);
+ }
+ }
+ }
mlir::LogicalResult BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values);
mlir::OpBuilder* GetOpBuilder() { return op_builder; }
mlir::Location GetLocation() { return loc; }
+ std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+ return tensor_map;
+ }
+ TosaSerializationHandler *GetTsh() const { return tsh; }
private:
mlir::Region* region;
@@ -1099,7 +1097,7 @@ private:
TosaSerializationHandler* tsh;
mlir::OpBuilder* op_builder;
mlir::Location loc;
- std::vector<TosaMlirBlockBuilder*> block_builders;
+ std::unordered_map<std::string, mlir::Value> tensor_map;
};
class TosaMlirBlockBuilder {
@@ -1115,28 +1113,185 @@ public:
mlir::OpBuilder* GetOpBuilder() { return region_builder->GetOpBuilder(); }
mlir::Location GetLocation() { return region_builder->GetLocation(); }
+ std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+ return region_builder->GetTensorMap();
+ }
+
+ TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); }
+ TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; }
private:
TosaSerializationBasicBlock* ser_block;
TosaMlirRegionBuilder* region_builder;
mlir::Block* block;
- std::unordered_map<std::string, mlir::Value> tensor_map;
std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
};
+TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const {
+ return block_builder->GetTsh();
+}
+
+TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const {
+ return block_builder->GetRegionBuilder();
+}
+
+// build control flow ops:
+
+namespace {
+
+mlir::LogicalResult
+BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh,
+ mlir::Region *mlir_region, mlir::OpBuilder *op_builder,
+ mlir::Location loc, std::vector<mlir::Value> &return_values,
+ bool isolated_from_above = false,
+ TosaMlirRegionBuilder *parent_region_builder = nullptr) {
+ TosaMlirRegionBuilder *parent_value_scope =
+ isolated_from_above ? nullptr : parent_region_builder;
+ TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder,
+ loc, parent_value_scope);
+ return region_builder.BuildAllBlocksInRegion(return_values);
+}
+
+} // namespace
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>(
+ TosaSerializationOperator *op) const {
+ mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0));
+ std::vector<mlir::Value> input_values;
+ for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) {
+ input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+ }
+ std::vector<mlir::Type> output_types;
+ for (auto &name : op->GetInputTensorNames()) {
+ output_types.push_back(tensor_type_map->at(name));
+ }
+
+ assert(op->GetAttributeType() ==
+ Attribute_CondIfAttribute); // double check attribute type
+ TosaCondIfAttribute *attr =
+ static_cast<TosaCondIfAttribute *>(op->GetAttribute());
+ auto ser_then_region = GetTsh()->GetRegionByName(attr->then_branch());
+ auto ser_else_region = GetTsh()->GetRegionByName(attr->else_branch());
+
+ if (!ser_then_region || !ser_else_region) {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " region serialization hasn't been implemented\n";
+ return {};
+ }
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::IfOp>(
+ loc, output_types, cond_val, input_values);
+
+ const bool isolated_from_above =
+ mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ mlir::Region &then_region = mlir_op->getRegion(0);
+ mlir::Region &else_region = mlir_op->getRegion(1);
+
+ auto curr_region_builder = GetRegionBuilder();
+
+ std::vector<mlir::Value> then_returns, else_returns;
+
+ if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc,
+ then_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (then_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " then_region yield.size() doesn't match cond_if's output size\n";
+ return {};
+ }
+
+ if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc,
+ else_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (else_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " else_region yield.size() doesn't match cond_if's output size\n";
+ return {};
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+ mlir_op->getResults().end());
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>(
+ TosaSerializationOperator *op) const {
+ std::vector<mlir::Value> input_values;
+ for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) {
+ input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+ }
+ std::vector<mlir::Type> output_types;
+ for (auto &name : op->GetInputTensorNames()) {
+ output_types.push_back(tensor_type_map->at(name));
+ }
+ assert(op->GetAttributeType() ==
+ Attribute_WhileLoopAttribute); // double check attribute type
+ TosaWhileLoopAttribute *attr =
+ static_cast<TosaWhileLoopAttribute *>(op->GetAttribute());
+ auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_branch());
+ auto ser_body_region = GetTsh()->GetRegionByName(attr->body_branch());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values);
+
+ const bool isolated_from_above =
+ mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+
+ mlir::Region &cond_region = mlir_op->getRegion(0);
+ mlir::Region &body_region = mlir_op->getRegion(1);
+
+ auto curr_region_builder = GetRegionBuilder();
+
+ std::vector<mlir::Value> cond_returns, body_returns;
+
+ if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc,
+ cond_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (cond_returns.size() != 1) {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " cond_region yield.size() is not 1\n";
+ return {};
+ }
+
+ if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc,
+ body_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (body_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " body_region yield.size() doesn't match while_loop's output size\n";
+ return {};
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+ mlir_op->getResults().end());
+}
+
mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
std::vector<mlir::Value> &return_values) {
block->clear();
auto loc = GetLocation();
auto op_builder = GetOpBuilder();
+ auto &tensor_map = GetTensorMap();
- std::unordered_map<std::string, std::vector<TosaSerializationOperator*>> consumer_map;
- std::unordered_map<std::string, bool> tensor_built;
- std::unordered_map<TosaSerializationOperator*, bool> operator_built;
+ std::unordered_set<TosaSerializationOperator *> operator_built;
std::queue<TosaSerializationOperator*> operator_queue;
TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc,
- &tensor_map, &tensor_type_map);
+ this, &tensor_map, &tensor_type_map);
for (auto ts : ser_block->GetTensors()) {
mlir::RankedTensorType type;
@@ -1145,28 +1300,19 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
const auto& ts_name = ts->GetName();
tensor_type_map[ts_name] = type;
- tensor_built[ts_name] = false;
}
- for (auto op : ser_block->GetOperators()) {
- operator_built[op] = false;
- for (auto ts_name : op->GetInputTensorNames()) {
- consumer_map[ts_name].push_back(op);
- }
- }
-
- // Update operator_queue if a consumer of tensor_name has all of its inputs already built
- auto queue_ready_consumers = [&](const std::string tensor_name) {
- for (auto consumer_op : consumer_map[tensor_name]) {
- // Sanity check operator hasn't been built
- if (operator_built[consumer_op]) {
- llvm::errs() << "ERROR: " << tosa_op_builder.get_string(consumer_op)
- << " is already built before its input is built\n";
- assert(0);
+ // Update operator_queue with operators whose inputs are all built
+ auto queue_ready_operators = [&]() {
+ // note: it is important to queue in order of original operators
+ // because an operator input may be defined in upper value scopes
+ for (auto consumer_op : ser_block->GetOperators()) {
+ if (operator_built.count(consumer_op)) {
+ continue;
}
bool all_inputs_ready = true;
for (const auto& input_name : consumer_op->GetInputTensorNames()) {
- if (!tensor_built[input_name]) {
+ if (!tensor_map.count(input_name)) {
all_inputs_ready = false;
break;
}
@@ -1177,7 +1323,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
};
- // Initialize tensor_map/tensor_built/operator_queue based on block input arguments
+ // Initialize tensor_map/operator_queue based on block input arguments
for (const std::string& block_input_name : ser_block->GetInputs()) {
auto type = tensor_type_map[block_input_name];
auto input_value = block->addArgument(type, loc);
@@ -1185,29 +1331,21 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
llvm::errs() << "ERROR: block input tensor " << block_input_name << " already exists\n";
return mlir::failure();
}
- tensor_built[block_input_name] = true;
tensor_map[block_input_name] = input_value;
- queue_ready_consumers(block_input_name);
}
- // add all operators with 0 inputs (e.g., constant operators) to
- // operator_queue
- for (auto op : ser_block->GetOperators()) {
- if (op->GetInputTensorNames().empty()) {
- operator_queue.push(op);
- }
- }
+ queue_ready_operators();
while (!operator_queue.empty()) {
TosaSerializationOperator* op = operator_queue.front();
operator_queue.pop();
// skip if operator has been built
- if (operator_built[op]) {
+ if (operator_built.count(op)) {
// this happens when same input appears twice or more in operator, eg, concat(%0, %0)
continue;
}
- operator_built[op] = true;
+ operator_built.insert(op);
std::vector<mlir::Value> output_values;
if (false) {
@@ -1232,22 +1370,23 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
for (size_t i = 0; i < output_values.size(); i++) {
// Sanity check tensor hasn't been built
std::string op_output_name = op->GetOutputTensorNames()[i];
- if (tensor_built[op_output_name]) {
+ if (tensor_map.count(op_output_name)) {
llvm::errs() << "ERROR: tensor " << op_output_name << " is already built\n";
return mlir::failure();
}
tensor_map[op_output_name] = output_values[i];
- tensor_built[op_output_name] = true;
- queue_ready_consumers(op_output_name);
}
+ // look for any more ready consumers
+ queue_ready_operators();
}
// Construct return values
std::vector<mlir::Value> return_operands;
for (const auto& output_name : ser_block->GetOutputs()) {
// Sanity check if terminator mlir::Value is built
- if (!tensor_built[output_name]) {
- llvm::errs() << "ERROR: terminator mlir::Value " << output_name << " is not built\n";
+ if (!tensor_map.count(output_name)) {
+ llvm::errs() << "ERROR: terminator mlir::Value " << output_name
+ << " is not built in block " << ser_block->GetName() << "\n";
return mlir::failure();
}
mlir::Value output_value = tensor_map.at(output_name);
@@ -1267,8 +1406,6 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
for (auto& ser_block : ser_region->GetBlocks()) {
auto& block = region->emplaceBlock();
TosaMlirBlockBuilder block_builder(ser_block, this, &block);
- // Region Builders need access to block builders (?)
- block_builders.push_back(&block_builder);
if (block_builder.BuildAllOpsInBlock(return_values).failed()) {
return mlir::failure();
@@ -1291,12 +1428,6 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
return mlir::failure();
}
- if (tsh.GetRegions().size() != 1) {
- llvm::errs() << "Internal Error: TosaSerializationHandler's region list "
- "must contain exactly one region\n";
- return mlir::failure();
- }
-
TosaSerializationRegion* ser_main_region = tsh.GetRegions().front();
auto loc = func.getLoc();
@@ -1305,8 +1436,9 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
main_region->takeBody(*main_region); // empty old func body
auto op_builder = mlir::OpBuilder(func.getBody());
- TosaMlirRegionBuilder region_builder(ser_main_region, &tsh, main_region, &op_builder, loc);
- if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) {
+ if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc,
+ main_returns)
+ .failed()) {
return mlir::failure();
}
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 8a95b68..4fd5014 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -123,8 +123,11 @@ public:
TosaSerializationOperatorBuilder(
TosaSerializationBlockBuilder *_block_builder)
: block_builder(_block_builder) {}
+
template <typename T>
TosaSerializationOperator *build(mlir::Operation &op) const;
+ TosaSerializationHandler *GetTsh() const;
+ TosaSerializationRegionBuilder *GetRegionBuilder() const;
private:
std::string GetTensorName(mlir::Value val) const;
@@ -142,8 +145,7 @@ private:
// This builder assumes each region only has only one block
class TosaSerializationBlockBuilder {
public:
- friend class TosaSerializationOperatorBuilder;
-
+ // constructor
TosaSerializationBlockBuilder(TosaSerializationBasicBlock* _ser_block,
TosaSerializationRegionBuilder* _region_builder,
mlir::Block* _block)
@@ -151,8 +153,14 @@ public:
mlir::LogicalResult
BuildAllOpsInBlock(std::vector<mlir::Value>& return_values);
- TosaSerializationBasicBlock* GetBlock() { return ser_block; }
- TosaSerializationRegionBuilder* GetRegionBuilder() { return region_builder; }
+ TosaSerializationBasicBlock *GetBlock() const { return ser_block; }
+ TosaSerializationRegionBuilder *GetRegionBuilder() const {
+ return region_builder;
+ }
+ TosaSerializationHandler *GetTsh() const;
+ std::unordered_map<mlir::Value, std::string> &GetTensorMap() {
+ return tensor_map;
+ }
private:
TosaSerializationOperator *BuildTosaSerializationOperator(
@@ -169,33 +177,54 @@ private:
class TosaSerializationRegionBuilder {
public:
- friend class TosaSerializationBlockBuilder;
- friend class TosaSerializationOperatorBuilder;
-
// Constructor
- TosaSerializationRegionBuilder(TosaSerializationRegion* _ser_region,
- TosaSerializationHandler* _tsh,
- mlir::Region* _region)
- : ser_region(_ser_region), tsh(_tsh), region(_region) {}
- TosaSerializationHandler* GetTsh() { return tsh; }
- mlir::LogicalResult BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values);
-
- int getNumBlocksInRegion() const { return ser_region->GetBlocks().size(); }
+ TosaSerializationRegionBuilder(
+ TosaSerializationRegion *_ser_region, mlir::Region *_region,
+ TosaSerializationRegionBuilder *_parent_value_scope,
+ TosaSerializationHandler *_tsh)
+ : ser_region(_ser_region), region(_region),
+ parent_value_scope(_parent_value_scope), tsh(_tsh) {}
+ TosaSerializationHandler *GetTsh() const { return tsh; }
+ mlir::LogicalResult
+ BuildAllBlocksInRegion(bool is_top, std::vector<mlir::Value> &return_values);
+ TosaSerializationRegionBuilder *GetParentValueScope() const {
+ return parent_value_scope;
+ }
+ std::vector<TosaSerializationBlockBuilder *> &GetBlockBuilders() {
+ return block_builders;
+ }
private:
- mlir::Region* region;
TosaSerializationRegion* ser_region;
+ mlir::Region *region;
+ TosaSerializationRegionBuilder *parent_value_scope;
TosaSerializationHandler* tsh;
std::vector<TosaSerializationBlockBuilder*> block_builders;
};
+TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const {
+ return block_builder->GetTsh();
+}
+TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const {
+ return region_builder->GetTsh();
+}
+TosaSerializationRegionBuilder *
+TosaSerializationOperatorBuilder::GetRegionBuilder() const {
+ return block_builder->GetRegionBuilder();
+}
+
std::string
TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const {
- // Traverse through each block builder in the region
- for (auto curr_block_builder : block_builder->region_builder->block_builders) {
- if (curr_block_builder->tensor_map.find(val) != curr_block_builder->tensor_map.end()) {
- return curr_block_builder->tensor_map[val];
+ auto value_scope = GetRegionBuilder();
+ while (value_scope) {
+ // Traverse through each block builder in the region
+ for (auto curr_block_builder : value_scope->GetBlockBuilders()) {
+ const auto &tensor_map = curr_block_builder->GetTensorMap();
+ if (tensor_map.count(val)) {
+ return tensor_map.at(val);
+ }
}
+ value_scope = value_scope->GetParentValueScope();
}
// Didn't find anything
llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n";
@@ -1195,32 +1224,57 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::CustomOp>(
return tyop;
}
+namespace {
+
+// serialize a region and all its blocks, and return region's return values
+TosaSerializationRegion *
+BuildRegion(mlir::Region &region, const std::string region_name,
+ const bool isolated_from_above,
+ TosaSerializationRegionBuilder *curr_region_builder,
+ TosaSerializationHandler *tsh,
+ std::vector<mlir::Value> &return_values, bool is_top = false) {
+ TosaSerializationRegion *ser_region =
+ new TosaSerializationRegion(region_name, {});
+ assert(ser_region);
+ tsh->GetRegions().push_back(ser_region);
+
+ TosaSerializationRegionBuilder *parent_value_scope =
+ isolated_from_above ? nullptr : curr_region_builder;
+
+ TosaSerializationRegionBuilder region_builder(ser_region, &region,
+ parent_value_scope, tsh);
+ if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) {
+ return nullptr;
+ }
+ return ser_region;
+}
+
+static int input_tensor_index = 0;
+static int intermediate_tensor_index = 0;
+static int output_tensor_index = 0;
+
+} // namespace
+
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
- mlir::Block& then_block = op.getRegion(0).front();
- mlir::Block& else_block = op.getRegion(1).front();
std::vector<mlir::Value> then_yields, else_yields;
- TosaSerializationBasicBlock* ser_then_block = nullptr;
- TosaSerializationBasicBlock* ser_else_block = nullptr;
-
- // Building then branch block
- std::string region_name = block_builder->region_builder->ser_region->GetName();
- std::string then_block_name =
- "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size());
- ser_then_block = new TosaSerializationBasicBlock(
- then_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_then_block);
- block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_then_block);
-
- TosaSerializationBlockBuilder then_block_builder(
- ser_then_block, block_builder->region_builder, &then_block);
- block_builder->region_builder->block_builders.push_back(&then_block_builder);
- if (then_block_builder.BuildAllOpsInBlock(then_yields).failed()) {
+ auto tsh = GetTsh();
+
+ mlir::Region &then_region = op.getRegion(0);
+ mlir::Region &else_region = op.getRegion(1);
+
+ const std::string then_region_name = op_name + "_then_region";
+ TosaSerializationRegion *ser_then_region =
+ BuildRegion(then_region, then_region_name, isolated_from_above,
+ curr_region_builder, tsh, then_yields);
+ if (!ser_then_region) {
return nullptr;
}
if (then_yields.size() != op.getNumResults()) {
@@ -1229,20 +1283,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- // Building else branch block
- std::string else_block_name =
- "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size());
- ser_else_block = new TosaSerializationBasicBlock(
- else_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_else_block);
- block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_else_block);
-
- TosaSerializationBlockBuilder else_block_builder(
- ser_else_block, block_builder->region_builder, &else_block);
- block_builder->region_builder->block_builders.push_back(&else_block_builder);
- if (else_block_builder.BuildAllOpsInBlock(else_yields).failed()) {
+ const std::string else_region_name = op_name + "_else_region";
+ TosaSerializationRegion *ser_else_region =
+ BuildRegion(else_region, else_region_name, isolated_from_above,
+ curr_region_builder, tsh, else_yields);
+ if (!ser_else_region) {
return nullptr;
}
if (else_yields.size() != op.getNumResults()) {
@@ -1251,7 +1296,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- TosaCondIfAttribute attribute(ser_then_block->GetName(), ser_else_block->GetName());
+ TosaCondIfAttribute attribute(then_region_name, else_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1263,9 +1308,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
output_names.push_back(output_name);
}
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_COND_IF, Attribute_CondIfAttribute, &attribute,
- input_names, output_names);
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute,
+ &attribute, input_names, output_names);
return tyop;
}
@@ -1274,32 +1319,22 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
+ auto tsh = GetTsh();
- mlir::Block& cond_block = op.getRegion(0).front();
- mlir::Block& body_block = op.getRegion(1).front();
+ mlir::Region &cond_region = op.getRegion(0);
+ mlir::Region &body_region = op.getRegion(1);
std::vector<mlir::Value> cond_yields, body_yields;
- TosaSerializationBasicBlock* ser_cond_block = nullptr;
- TosaSerializationBasicBlock* ser_body_block = nullptr;
-
- // Building cond branch block
- std::string cond_block_name =
- "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion());
-
- std::string region_name = block_builder->region_builder->ser_region->GetName();
- ser_cond_block = new TosaSerializationBasicBlock(
- cond_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_cond_block);
- block_builder->region_builder->ser_region->GetBlocks().push_back(ser_cond_block);
-
- TosaSerializationBlockBuilder cond_block_builder(
- ser_cond_block, block_builder->region_builder, &cond_block);
- block_builder->region_builder->block_builders.push_back(&cond_block_builder);
-
- if (cond_block_builder.BuildAllOpsInBlock(cond_yields).failed()) {
+ const std::string cond_region_name = op_name + "_cond_region";
+ TosaSerializationRegion *ser_cond_region =
+ BuildRegion(cond_region, cond_region_name, isolated_from_above,
+ curr_region_builder, tsh, cond_yields);
+ if (!ser_cond_region) {
return nullptr;
}
if (cond_yields.size() != 1) {
@@ -1307,21 +1342,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
-
- // Building body branch block
- std::string body_block_name =
- "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion());
- ser_body_block = new TosaSerializationBasicBlock(
- body_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_body_block);
- block_builder->region_builder->ser_region->GetBlocks().push_back(ser_body_block);
-
- TosaSerializationBlockBuilder body_block_builder(
- ser_body_block, block_builder->region_builder, &body_block);
- block_builder->region_builder->block_builders.push_back(&body_block_builder);
- if (body_block_builder.BuildAllOpsInBlock(body_yields).failed()) {
+ const std::string body_region_name = op_name + "_body_region";
+ TosaSerializationRegion *ser_body_region =
+ BuildRegion(body_region, body_region_name, isolated_from_above,
+ curr_region_builder, tsh, body_yields);
+ if (!ser_body_region) {
return nullptr;
}
if (body_yields.size() != op.getNumResults()) {
@@ -1330,9 +1355,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
-
- TosaWhileLoopAttribute attribute(ser_cond_block->GetName(),
- ser_body_block->GetName());
+ TosaWhileLoopAttribute attribute(cond_region_name, body_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1344,9 +1367,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
output_names.push_back(output_name);
}
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute,
- input_names, output_names);
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute,
+ &attribute, input_names, output_names);
return tyop;
}
@@ -1390,16 +1413,20 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>(
}
/* End translating TOSA operator */
-mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values) {
+mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(
+ bool is_top, std::vector<mlir::Value> &return_values) {
std::string region_name = ser_region->GetName();
- // this will likely run once for most cases.
+ int block_index = 0;
for (auto& block : this->region->getBlocks()) {
- // TODO: update the block name
- TosaSerializationBasicBlock* ser_block = new TosaSerializationBasicBlock(
- std::string("main"), region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>()
- );
+ // must name first block of top region "main"
+ const std::string block_name =
+ (is_top && block_index == 0)
+ ? "main"
+ : (region_name + "_bb" + std::to_string(block_index++));
+ TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock(
+ block_name, region_name, std::vector<TosaSerializationOperator *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
// build the block
TosaSerializationBlockBuilder block_builder(ser_block, this, &block);
@@ -1421,16 +1448,11 @@ mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::
return mlir::success();
}
-
-
mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
std::vector<mlir::Value> &return_values) {
TosaSerializationOperator *ser_operator = nullptr;
TosaSerializationTensor *ser_tensor = nullptr;
size_t num_blocks_in_region = 0;
- static int input_tensor_index = 0;
- static int intermediate_tensor_index = 0;
- static int output_tensor_index = 0;
TosaSerializationOperatorBuilder op_builder(this);
// Specify block input tensor name
@@ -1528,7 +1550,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
TosaSerializationOperator *
TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) {
- std::string full_op_name = op.getName().getStringRef().str();
TosaSerializationOperator *target_operator = nullptr;
if (false) {
@@ -1600,8 +1621,6 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
TosaSerializationHandler &tsh) {
- TosaSerializationRegion* ser_main_region;
-
mlir::Region *main_region = func.getCallableRegion();
std::vector<mlir::Value> main_returns;
@@ -1616,15 +1635,16 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
return mlir::failure();
}
- ser_main_region = new TosaSerializationRegion(
- std::string("main"), /* region_name */
- std::vector<TosaSerializationBasicBlock*>() /* empty serialized block container */
- );
- assert(ser_main_region);
- tsh.GetRegions().push_back(ser_main_region);
+ // reset static counters
+ input_tensor_index = 0;
+ intermediate_tensor_index = 0;
+ output_tensor_index = 0;
- TosaSerializationRegionBuilder region_builder(ser_main_region, &tsh, main_region);
- if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) {
+ TosaSerializationRegion *ser_main_region =
+ BuildRegion(*main_region, "main", /* isolated_from_above = */ true,
+ /* parent_value_scope = */ nullptr, &tsh, main_returns,
+ /* is_top = */ true);
+ if (!ser_main_region) {
return mlir::failure();
}