aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-16 22:22:41 +0000
committerTai Ly <tai.ly@arm.com>2023-04-26 00:23:38 +0000
commit8ffce6d63372ea93f3e060571137bce11d4735d8 (patch)
tree00b8a75f1cc5f6331818fefa750c08131f801f5c /src/TosaSerialize.cpp
parent828099da4d3777254688f3a75063a705762064d6 (diff)
downloadtosa_mlir_translator-8ffce6d63372ea93f3e060571137bce11d4735d8.tar.gz
serialize/deserialize while/if ops using regions
This changes serialization and deserialization of while and if ops to use regions instead of blocks. also changed deserialization to preserve original ordering such that: tosa1->deserialize->serialize->tosa2 => tosa1 == tosa2 most of the time. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I539076788bdc466ba1881d955af349e6b4924ed8
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp270
1 files changed, 145 insertions, 125 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 8a95b68..4fd5014 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -123,8 +123,11 @@ public:
TosaSerializationOperatorBuilder(
TosaSerializationBlockBuilder *_block_builder)
: block_builder(_block_builder) {}
+
template <typename T>
TosaSerializationOperator *build(mlir::Operation &op) const;
+ TosaSerializationHandler *GetTsh() const;
+ TosaSerializationRegionBuilder *GetRegionBuilder() const;
private:
std::string GetTensorName(mlir::Value val) const;
@@ -142,8 +145,7 @@ private:
// This builder assumes each region only has only one block
class TosaSerializationBlockBuilder {
public:
- friend class TosaSerializationOperatorBuilder;
-
+ // constructor
TosaSerializationBlockBuilder(TosaSerializationBasicBlock* _ser_block,
TosaSerializationRegionBuilder* _region_builder,
mlir::Block* _block)
@@ -151,8 +153,14 @@ public:
mlir::LogicalResult
BuildAllOpsInBlock(std::vector<mlir::Value>& return_values);
- TosaSerializationBasicBlock* GetBlock() { return ser_block; }
- TosaSerializationRegionBuilder* GetRegionBuilder() { return region_builder; }
+ TosaSerializationBasicBlock *GetBlock() const { return ser_block; }
+ TosaSerializationRegionBuilder *GetRegionBuilder() const {
+ return region_builder;
+ }
+ TosaSerializationHandler *GetTsh() const;
+ std::unordered_map<mlir::Value, std::string> &GetTensorMap() {
+ return tensor_map;
+ }
private:
TosaSerializationOperator *BuildTosaSerializationOperator(
@@ -169,33 +177,54 @@ private:
class TosaSerializationRegionBuilder {
public:
- friend class TosaSerializationBlockBuilder;
- friend class TosaSerializationOperatorBuilder;
-
// Constructor
- TosaSerializationRegionBuilder(TosaSerializationRegion* _ser_region,
- TosaSerializationHandler* _tsh,
- mlir::Region* _region)
- : ser_region(_ser_region), tsh(_tsh), region(_region) {}
- TosaSerializationHandler* GetTsh() { return tsh; }
- mlir::LogicalResult BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values);
-
- int getNumBlocksInRegion() const { return ser_region->GetBlocks().size(); }
+ TosaSerializationRegionBuilder(
+ TosaSerializationRegion *_ser_region, mlir::Region *_region,
+ TosaSerializationRegionBuilder *_parent_value_scope,
+ TosaSerializationHandler *_tsh)
+ : ser_region(_ser_region), region(_region),
+ parent_value_scope(_parent_value_scope), tsh(_tsh) {}
+ TosaSerializationHandler *GetTsh() const { return tsh; }
+ mlir::LogicalResult
+ BuildAllBlocksInRegion(bool is_top, std::vector<mlir::Value> &return_values);
+ TosaSerializationRegionBuilder *GetParentValueScope() const {
+ return parent_value_scope;
+ }
+ std::vector<TosaSerializationBlockBuilder *> &GetBlockBuilders() {
+ return block_builders;
+ }
private:
- mlir::Region* region;
TosaSerializationRegion* ser_region;
+ mlir::Region *region;
+ TosaSerializationRegionBuilder *parent_value_scope;
TosaSerializationHandler* tsh;
std::vector<TosaSerializationBlockBuilder*> block_builders;
};
+TosaSerializationHandler *TosaSerializationOperatorBuilder::GetTsh() const {
+ return block_builder->GetTsh();
+}
+TosaSerializationHandler *TosaSerializationBlockBuilder::GetTsh() const {
+ return region_builder->GetTsh();
+}
+TosaSerializationRegionBuilder *
+TosaSerializationOperatorBuilder::GetRegionBuilder() const {
+ return block_builder->GetRegionBuilder();
+}
+
std::string
TosaSerializationOperatorBuilder::GetTensorName(mlir::Value val) const {
- // Traverse through each block builder in the region
- for (auto curr_block_builder : block_builder->region_builder->block_builders) {
- if (curr_block_builder->tensor_map.find(val) != curr_block_builder->tensor_map.end()) {
- return curr_block_builder->tensor_map[val];
+ auto value_scope = GetRegionBuilder();
+ while (value_scope) {
+ // Traverse through each block builder in the region
+ for (auto curr_block_builder : value_scope->GetBlockBuilders()) {
+ const auto &tensor_map = curr_block_builder->GetTensorMap();
+ if (tensor_map.count(val)) {
+ return tensor_map.at(val);
+ }
}
+ value_scope = value_scope->GetParentValueScope();
}
// Didn't find anything
llvm::errs() << "ERROR: Failed to get mlir::Value from tensor_map\n";
@@ -1195,32 +1224,57 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::CustomOp>(
return tyop;
}
+namespace {
+
+// serialize a region and all its blocks, and return region's return values
+TosaSerializationRegion *
+BuildRegion(mlir::Region &region, const std::string region_name,
+ const bool isolated_from_above,
+ TosaSerializationRegionBuilder *curr_region_builder,
+ TosaSerializationHandler *tsh,
+ std::vector<mlir::Value> &return_values, bool is_top = false) {
+ TosaSerializationRegion *ser_region =
+ new TosaSerializationRegion(region_name, {});
+ assert(ser_region);
+ tsh->GetRegions().push_back(ser_region);
+
+ TosaSerializationRegionBuilder *parent_value_scope =
+ isolated_from_above ? nullptr : curr_region_builder;
+
+ TosaSerializationRegionBuilder region_builder(ser_region, &region,
+ parent_value_scope, tsh);
+ if (region_builder.BuildAllBlocksInRegion(is_top, return_values).failed()) {
+ return nullptr;
+ }
+ return ser_region;
+}
+
+static int input_tensor_index = 0;
+static int intermediate_tensor_index = 0;
+static int output_tensor_index = 0;
+
+} // namespace
+
template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
- mlir::Block& then_block = op.getRegion(0).front();
- mlir::Block& else_block = op.getRegion(1).front();
std::vector<mlir::Value> then_yields, else_yields;
- TosaSerializationBasicBlock* ser_then_block = nullptr;
- TosaSerializationBasicBlock* ser_else_block = nullptr;
-
- // Building then branch block
- std::string region_name = block_builder->region_builder->ser_region->GetName();
- std::string then_block_name =
- "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size());
- ser_then_block = new TosaSerializationBasicBlock(
- then_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_then_block);
- block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_then_block);
-
- TosaSerializationBlockBuilder then_block_builder(
- ser_then_block, block_builder->region_builder, &then_block);
- block_builder->region_builder->block_builders.push_back(&then_block_builder);
- if (then_block_builder.BuildAllOpsInBlock(then_yields).failed()) {
+ auto tsh = GetTsh();
+
+ mlir::Region &then_region = op.getRegion(0);
+ mlir::Region &else_region = op.getRegion(1);
+
+ const std::string then_region_name = op_name + "_then_region";
+ TosaSerializationRegion *ser_then_region =
+ BuildRegion(then_region, then_region_name, isolated_from_above,
+ curr_region_builder, tsh, then_yields);
+ if (!ser_then_region) {
return nullptr;
}
if (then_yields.size() != op.getNumResults()) {
@@ -1229,20 +1283,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- // Building else branch block
- std::string else_block_name =
- "bb" + std::to_string(block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().size());
- ser_else_block = new TosaSerializationBasicBlock(
- else_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_else_block);
- block_builder->region_builder->GetTsh()->GetMainRegion()->GetBlocks().push_back(ser_else_block);
-
- TosaSerializationBlockBuilder else_block_builder(
- ser_else_block, block_builder->region_builder, &else_block);
- block_builder->region_builder->block_builders.push_back(&else_block_builder);
- if (else_block_builder.BuildAllOpsInBlock(else_yields).failed()) {
+ const std::string else_region_name = op_name + "_else_region";
+ TosaSerializationRegion *ser_else_region =
+ BuildRegion(else_region, else_region_name, isolated_from_above,
+ curr_region_builder, tsh, else_yields);
+ if (!ser_else_region) {
return nullptr;
}
if (else_yields.size() != op.getNumResults()) {
@@ -1251,7 +1296,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
return nullptr;
}
- TosaCondIfAttribute attribute(ser_then_block->GetName(), ser_else_block->GetName());
+ TosaCondIfAttribute attribute(then_region_name, else_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1263,9 +1308,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::IfOp>(
output_names.push_back(output_name);
}
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_COND_IF, Attribute_CondIfAttribute, &attribute,
- input_names, output_names);
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_COND_IF, Attribute_CondIfAttribute,
+ &attribute, input_names, output_names);
return tyop;
}
@@ -1274,32 +1319,22 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
mlir::Operation &op) const {
+ const std::string op_name = op.getName().getStringRef().str();
+ const bool isolated_from_above =
+ op.hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ auto curr_region_builder = GetRegionBuilder();
std::vector<std::string> input_names, output_names;
+ auto tsh = GetTsh();
- mlir::Block& cond_block = op.getRegion(0).front();
- mlir::Block& body_block = op.getRegion(1).front();
+ mlir::Region &cond_region = op.getRegion(0);
+ mlir::Region &body_region = op.getRegion(1);
std::vector<mlir::Value> cond_yields, body_yields;
- TosaSerializationBasicBlock* ser_cond_block = nullptr;
- TosaSerializationBasicBlock* ser_body_block = nullptr;
-
- // Building cond branch block
- std::string cond_block_name =
- "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion());
-
- std::string region_name = block_builder->region_builder->ser_region->GetName();
- ser_cond_block = new TosaSerializationBasicBlock(
- cond_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_cond_block);
- block_builder->region_builder->ser_region->GetBlocks().push_back(ser_cond_block);
-
- TosaSerializationBlockBuilder cond_block_builder(
- ser_cond_block, block_builder->region_builder, &cond_block);
- block_builder->region_builder->block_builders.push_back(&cond_block_builder);
-
- if (cond_block_builder.BuildAllOpsInBlock(cond_yields).failed()) {
+ const std::string cond_region_name = op_name + "_cond_region";
+ TosaSerializationRegion *ser_cond_region =
+ BuildRegion(cond_region, cond_region_name, isolated_from_above,
+ curr_region_builder, tsh, cond_yields);
+ if (!ser_cond_region) {
return nullptr;
}
if (cond_yields.size() != 1) {
@@ -1307,21 +1342,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
-
- // Building body branch block
- std::string body_block_name =
- "bb" + std::to_string(block_builder->region_builder->getNumBlocksInRegion());
- ser_body_block = new TosaSerializationBasicBlock(
- body_block_name, region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>());
- assert(ser_body_block);
- block_builder->region_builder->ser_region->GetBlocks().push_back(ser_body_block);
-
- TosaSerializationBlockBuilder body_block_builder(
- ser_body_block, block_builder->region_builder, &body_block);
- block_builder->region_builder->block_builders.push_back(&body_block_builder);
- if (body_block_builder.BuildAllOpsInBlock(body_yields).failed()) {
+ const std::string body_region_name = op_name + "_body_region";
+ TosaSerializationRegion *ser_body_region =
+ BuildRegion(body_region, body_region_name, isolated_from_above,
+ curr_region_builder, tsh, body_yields);
+ if (!ser_body_region) {
return nullptr;
}
if (body_yields.size() != op.getNumResults()) {
@@ -1330,9 +1355,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
return nullptr;
}
-
- TosaWhileLoopAttribute attribute(ser_cond_block->GetName(),
- ser_body_block->GetName());
+ TosaWhileLoopAttribute attribute(cond_region_name, body_region_name);
for (size_t i = 0; i < op.getNumOperands(); i++) {
std::string input_name = GetTensorName(op.getOperand(i));
@@ -1344,9 +1367,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::WhileOp>(
output_names.push_back(output_name);
}
- TosaSerializationOperator *tyop = new TosaSerializationOperator(
- Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute,
- input_names, output_names);
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_WHILE_LOOP, Attribute_WhileLoopAttribute,
+ &attribute, input_names, output_names);
return tyop;
}
@@ -1390,16 +1413,20 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>(
}
/* End translating TOSA operator */
-mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values) {
+mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(
+ bool is_top, std::vector<mlir::Value> &return_values) {
std::string region_name = ser_region->GetName();
- // this will likely run once for most cases.
+ int block_index = 0;
for (auto& block : this->region->getBlocks()) {
- // TODO: update the block name
- TosaSerializationBasicBlock* ser_block = new TosaSerializationBasicBlock(
- std::string("main"), region_name, std::vector<TosaSerializationOperator *>(),
- std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
- std::vector<std::string>()
- );
+ // must name first block of top region "main"
+ const std::string block_name =
+ (is_top && block_index == 0)
+ ? "main"
+ : (region_name + "_bb" + std::to_string(block_index++));
+ TosaSerializationBasicBlock *ser_block = new TosaSerializationBasicBlock(
+ block_name, region_name, std::vector<TosaSerializationOperator *>(),
+ std::vector<TosaSerializationTensor *>(), std::vector<std::string>(),
+ std::vector<std::string>());
// build the block
TosaSerializationBlockBuilder block_builder(ser_block, this, &block);
@@ -1421,16 +1448,11 @@ mlir::LogicalResult TosaSerializationRegionBuilder::BuildAllBlocksInRegion(std::
return mlir::success();
}
-
-
mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
std::vector<mlir::Value> &return_values) {
TosaSerializationOperator *ser_operator = nullptr;
TosaSerializationTensor *ser_tensor = nullptr;
size_t num_blocks_in_region = 0;
- static int input_tensor_index = 0;
- static int intermediate_tensor_index = 0;
- static int output_tensor_index = 0;
TosaSerializationOperatorBuilder op_builder(this);
// Specify block input tensor name
@@ -1528,7 +1550,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
TosaSerializationOperator *
TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op) {
- std::string full_op_name = op.getName().getStringRef().str();
TosaSerializationOperator *target_operator = nullptr;
if (false) {
@@ -1600,8 +1621,6 @@ TosaSerializationBlockBuilder::BuildTosaSerializationTensor(
mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
TosaSerializationHandler &tsh) {
- TosaSerializationRegion* ser_main_region;
-
mlir::Region *main_region = func.getCallableRegion();
std::vector<mlir::Value> main_returns;
@@ -1616,15 +1635,16 @@ mlir::LogicalResult translate2FlatBuffer(mlir::func::FuncOp &func,
return mlir::failure();
}
- ser_main_region = new TosaSerializationRegion(
- std::string("main"), /* region_name */
- std::vector<TosaSerializationBasicBlock*>() /* empty serialized block container */
- );
- assert(ser_main_region);
- tsh.GetRegions().push_back(ser_main_region);
+ // reset static counters
+ input_tensor_index = 0;
+ intermediate_tensor_index = 0;
+ output_tensor_index = 0;
- TosaSerializationRegionBuilder region_builder(ser_main_region, &tsh, main_region);
- if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) {
+ TosaSerializationRegion *ser_main_region =
+ BuildRegion(*main_region, "main", /* isolated_from_above = */ true,
+ /* parent_value_scope = */ nullptr, &tsh, main_returns,
+ /* is_top = */ true);
+ if (!ser_main_region) {
return mlir::failure();
}