aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp290
1 files changed, 211 insertions, 79 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 3d38e1b..495d6f0 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -25,9 +25,9 @@
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tosa_serialization_handler.h"
#include <functional>
-#include <map>
#include <queue>
#include <unordered_map>
+#include <unordered_set>
#include <vector>
// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR
@@ -155,10 +155,12 @@ public:
TosaMlirOperatorBuilder(
mlir::OpBuilder *_op_builder, TosaSerializationBasicBlock *_ser_block,
mlir::Block *_block, mlir::Location _loc,
+ TosaMlirBlockBuilder *_block_builder,
std::unordered_map<std::string, mlir::Value> *_tensor_map,
std::unordered_map<std::string, mlir::RankedTensorType> *_tensor_type_map)
: op_builder(_op_builder), ser_block(_ser_block), block(_block),
- loc(_loc), tensor_map(_tensor_map), tensor_type_map(_tensor_type_map) {}
+ loc(_loc), block_builder(_block_builder), tensor_map(_tensor_map),
+ tensor_type_map(_tensor_type_map) {}
template <Op OPCODE>
std::vector<mlir::Value> build(TosaSerializationOperator *op) const;
@@ -179,6 +181,9 @@ public:
return op_string;
}
+ TosaSerializationHandler *GetTsh() const;
+ TosaMlirRegionBuilder *GetRegionBuilder() const;
+
private:
template <class MLIR_OP>
std::vector<mlir::Value>
@@ -199,6 +204,7 @@ private:
TosaSerializationBasicBlock *ser_block;
mlir::Block *block;
mlir::Location loc;
+ TosaMlirBlockBuilder *block_builder;
std::unordered_map<std::string, mlir::Value> *tensor_map;
std::unordered_map<std::string, mlir::RankedTensorType> *tensor_type_map;
};
@@ -208,7 +214,7 @@ template <Op OPCODE>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build(TosaSerializationOperator* op) const
{
llvm::errs() << "ERROR: " << get_string(op) << " translation hasn't been implemented\n";
- return std::vector<mlir::Value>();
+ return {};
}
// BUILD_OP_POOL2D(MaxPool2d, MAX_POOL2D)
@@ -428,7 +434,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CONST>(TosaSerializat
default:
llvm::errs() << "ERROR: " << get_string(op)
<< " contains unsupported element type\n";
- return std::vector<mlir::Value>();
+ return {};
}
mlir::Operation *mlir_op =
op_builder->create<mlir::tosa::ConstOp>(loc, output_type, value_attr);
@@ -1020,25 +1026,6 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_CUSTOM>(TosaSerializa
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
-// TosaSerializationBasicBlock
-// template <>
-// std::vector<mlir::Value>
-// TosaMlirOperatorBuilder::build<COND_IF>(TosaSerializationOperator* op) const
-//{
-// // todo: mlir::tosa::IfOp
-// return {};
-//}
-
-// TosaSerializationBasicBlock
-// template <>
-// std::vector<mlir::Value>
-// TosaMlirOperatorBuilder::build<WHILE_LOOP>(TosaSerializationOperator* op)
-// const
-//{
-// // todo: mlir::tosa::WhileOp
-// return {};
-//}
-
template <>
std::vector<mlir::Value>
TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const {
@@ -1081,17 +1068,28 @@ TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const {
class TosaMlirRegionBuilder {
public:
- TosaMlirRegionBuilder(TosaSerializationRegion* _ser_region,
- TosaSerializationHandler* _tsh,
- mlir::Region* _region,
- mlir::OpBuilder* _op_builder,
- mlir::Location _loc)
- : ser_region(_ser_region), tsh(_tsh), region(_region), op_builder(_op_builder), loc(_loc) {}
+ TosaMlirRegionBuilder(TosaSerializationRegion *_ser_region,
+ TosaSerializationHandler *_tsh, mlir::Region *_region,
+ mlir::OpBuilder *_op_builder, mlir::Location _loc,
+ TosaMlirRegionBuilder *_parent_value_scope = nullptr)
+ : ser_region(_ser_region), tsh(_tsh), region(_region),
+ op_builder(_op_builder), loc(_loc) {
+ if (_parent_value_scope) {
+ // inherit parent_value_scope's tensor_map
+ for (auto &kv : _parent_value_scope->GetTensorMap()) {
+ tensor_map.insert(kv);
+ }
+ }
+ }
mlir::LogicalResult BuildAllBlocksInRegion(std::vector<mlir::Value>& return_values);
mlir::OpBuilder* GetOpBuilder() { return op_builder; }
mlir::Location GetLocation() { return loc; }
+ std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+ return tensor_map;
+ }
+ TosaSerializationHandler *GetTsh() const { return tsh; }
private:
mlir::Region* region;
@@ -1099,7 +1097,7 @@ private:
TosaSerializationHandler* tsh;
mlir::OpBuilder* op_builder;
mlir::Location loc;
- std::vector<TosaMlirBlockBuilder*> block_builders;
+ std::unordered_map<std::string, mlir::Value> tensor_map;
};
class TosaMlirBlockBuilder {
@@ -1115,28 +1113,185 @@ public:
mlir::OpBuilder* GetOpBuilder() { return region_builder->GetOpBuilder(); }
mlir::Location GetLocation() { return region_builder->GetLocation(); }
+ std::unordered_map<std::string, mlir::Value> &GetTensorMap() {
+ return region_builder->GetTensorMap();
+ }
+
+ TosaSerializationHandler *GetTsh() const { return region_builder->GetTsh(); }
+ TosaMlirRegionBuilder *GetRegionBuilder() const { return region_builder; }
private:
TosaSerializationBasicBlock* ser_block;
TosaMlirRegionBuilder* region_builder;
mlir::Block* block;
- std::unordered_map<std::string, mlir::Value> tensor_map;
std::unordered_map<std::string, mlir::RankedTensorType> tensor_type_map;
};
+TosaSerializationHandler *TosaMlirOperatorBuilder::GetTsh() const {
+ return block_builder->GetTsh();
+}
+
+TosaMlirRegionBuilder *TosaMlirOperatorBuilder::GetRegionBuilder() const {
+ return block_builder->GetRegionBuilder();
+}
+
+// build control flow ops:
+
+namespace {
+
+mlir::LogicalResult
+BuildRegion(TosaSerializationRegion *ser_region, TosaSerializationHandler *tsh,
+ mlir::Region *mlir_region, mlir::OpBuilder *op_builder,
+ mlir::Location loc, std::vector<mlir::Value> &return_values,
+ bool isolated_from_above = false,
+ TosaMlirRegionBuilder *parent_region_builder = nullptr) {
+ TosaMlirRegionBuilder *parent_value_scope =
+ isolated_from_above ? nullptr : parent_region_builder;
+ TosaMlirRegionBuilder region_builder(ser_region, tsh, mlir_region, op_builder,
+ loc, parent_value_scope);
+ return region_builder.BuildAllBlocksInRegion(return_values);
+}
+
+} // namespace
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_COND_IF>(
+ TosaSerializationOperator *op) const {
+ mlir::Value cond_val = tensor_map->at(op->GetInputTensorNames().at(0));
+ std::vector<mlir::Value> input_values;
+ for (auto idx = 1u; idx < op->GetInputTensorNames().size(); idx++) {
+ input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+ }
+ std::vector<mlir::Type> output_types;
+ for (auto &name : op->GetInputTensorNames()) {
+ output_types.push_back(tensor_type_map->at(name));
+ }
+
+ assert(op->GetAttributeType() ==
+ Attribute_CondIfAttribute); // double check attribute type
+ TosaCondIfAttribute *attr =
+ static_cast<TosaCondIfAttribute *>(op->GetAttribute());
+ auto ser_then_region = GetTsh()->GetRegionByName(attr->then_branch());
+ auto ser_else_region = GetTsh()->GetRegionByName(attr->else_branch());
+
+ if (!ser_then_region || !ser_else_region) {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " region serialization hasn't been implemented\n";
+ return {};
+ }
+
+ mlir::Operation *mlir_op = op_builder->create<mlir::tosa::IfOp>(
+ loc, output_types, cond_val, input_values);
+
+ const bool isolated_from_above =
+ mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+ mlir::Region &then_region = mlir_op->getRegion(0);
+ mlir::Region &else_region = mlir_op->getRegion(1);
+
+ auto curr_region_builder = GetRegionBuilder();
+
+ std::vector<mlir::Value> then_returns, else_returns;
+
+ if (BuildRegion(ser_then_region, GetTsh(), &then_region, op_builder, loc,
+ then_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (then_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " then_region yield.size() doesn't match cond_if's output size\n";
+ return {};
+ }
+
+ if (BuildRegion(ser_else_region, GetTsh(), &else_region, op_builder, loc,
+ else_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (else_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " else_region yield.size() doesn't match cond_if's output size\n";
+ return {};
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+ mlir_op->getResults().end());
+}
+
+template <>
+std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_WHILE_LOOP>(
+ TosaSerializationOperator *op) const {
+ std::vector<mlir::Value> input_values;
+ for (auto idx = 0u; idx < op->GetInputTensorNames().size(); idx++) {
+ input_values.push_back(tensor_map->at(op->GetInputTensorNames().at(idx)));
+ }
+ std::vector<mlir::Type> output_types;
+ for (auto &name : op->GetInputTensorNames()) {
+ output_types.push_back(tensor_type_map->at(name));
+ }
+ assert(op->GetAttributeType() ==
+ Attribute_WhileLoopAttribute); // double check attribute type
+ TosaWhileLoopAttribute *attr =
+ static_cast<TosaWhileLoopAttribute *>(op->GetAttribute());
+ auto ser_cond_region = GetTsh()->GetRegionByName(attr->cond_branch());
+ auto ser_body_region = GetTsh()->GetRegionByName(attr->body_branch());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::WhileOp>(loc, output_types, input_values);
+
+ const bool isolated_from_above =
+ mlir_op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>();
+
+ mlir::Region &cond_region = mlir_op->getRegion(0);
+ mlir::Region &body_region = mlir_op->getRegion(1);
+
+ auto curr_region_builder = GetRegionBuilder();
+
+ std::vector<mlir::Value> cond_returns, body_returns;
+
+ if (BuildRegion(ser_cond_region, GetTsh(), &cond_region, op_builder, loc,
+ cond_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (cond_returns.size() != 1) {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " cond_region yield.size() is not 1\n";
+ return {};
+ }
+
+ if (BuildRegion(ser_body_region, GetTsh(), &body_region, op_builder, loc,
+ body_returns, isolated_from_above, curr_region_builder)
+ .failed()) {
+ return {};
+ }
+ if (body_returns.size() != mlir_op->getNumResults()) {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " body_region yield.size() doesn't match while_loop's output size\n";
+ return {};
+ }
+
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>(mlir_op->getResults().begin(),
+ mlir_op->getResults().end());
+}
+
mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
std::vector<mlir::Value> &return_values) {
block->clear();
auto loc = GetLocation();
auto op_builder = GetOpBuilder();
+ auto &tensor_map = GetTensorMap();
- std::unordered_map<std::string, std::vector<TosaSerializationOperator*>> consumer_map;
- std::unordered_map<std::string, bool> tensor_built;
- std::unordered_map<TosaSerializationOperator*, bool> operator_built;
+ std::unordered_set<TosaSerializationOperator *> operator_built;
std::queue<TosaSerializationOperator*> operator_queue;
TosaMlirOperatorBuilder tosa_op_builder(op_builder, ser_block, block, loc,
- &tensor_map, &tensor_type_map);
+ this, &tensor_map, &tensor_type_map);
for (auto ts : ser_block->GetTensors()) {
mlir::RankedTensorType type;
@@ -1145,28 +1300,19 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
const auto& ts_name = ts->GetName();
tensor_type_map[ts_name] = type;
- tensor_built[ts_name] = false;
}
- for (auto op : ser_block->GetOperators()) {
- operator_built[op] = false;
- for (auto ts_name : op->GetInputTensorNames()) {
- consumer_map[ts_name].push_back(op);
- }
- }
-
- // Update operator_queue if a consumer of tensor_name has all of its inputs already built
- auto queue_ready_consumers = [&](const std::string tensor_name) {
- for (auto consumer_op : consumer_map[tensor_name]) {
- // Sanity check operator hasn't been built
- if (operator_built[consumer_op]) {
- llvm::errs() << "ERROR: " << tosa_op_builder.get_string(consumer_op)
- << " is already built before its input is built\n";
- assert(0);
+ // Update operator_queue with operators whose inputs are all built
+ auto queue_ready_operators = [&]() {
+ // note: it is important to queue in order of original operators
+ // because an operator input may be defined in upper value scopes
+ for (auto consumer_op : ser_block->GetOperators()) {
+ if (operator_built.count(consumer_op)) {
+ continue;
}
bool all_inputs_ready = true;
for (const auto& input_name : consumer_op->GetInputTensorNames()) {
- if (!tensor_built[input_name]) {
+ if (!tensor_map.count(input_name)) {
all_inputs_ready = false;
break;
}
@@ -1177,7 +1323,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
}
};
- // Initialize tensor_map/tensor_built/operator_queue based on block input arguments
+ // Initialize tensor_map/operator_queue based on block input arguments
for (const std::string& block_input_name : ser_block->GetInputs()) {
auto type = tensor_type_map[block_input_name];
auto input_value = block->addArgument(type, loc);
@@ -1185,29 +1331,21 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
llvm::errs() << "ERROR: block input tensor " << block_input_name << " already exists\n";
return mlir::failure();
}
- tensor_built[block_input_name] = true;
tensor_map[block_input_name] = input_value;
- queue_ready_consumers(block_input_name);
}
- // add all operators with 0 inputs (e.g., constant operators) to
- // operator_queue
- for (auto op : ser_block->GetOperators()) {
- if (op->GetInputTensorNames().empty()) {
- operator_queue.push(op);
- }
- }
+ queue_ready_operators();
while (!operator_queue.empty()) {
TosaSerializationOperator* op = operator_queue.front();
operator_queue.pop();
// skip if operator has been built
- if (operator_built[op]) {
+ if (operator_built.count(op)) {
// this happens when same input appears twice or more in operator, eg, concat(%0, %0)
continue;
}
- operator_built[op] = true;
+ operator_built.insert(op);
std::vector<mlir::Value> output_values;
if (false) {
@@ -1232,22 +1370,23 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock(
for (size_t i = 0; i < output_values.size(); i++) {
// Sanity check tensor hasn't been built
std::string op_output_name = op->GetOutputTensorNames()[i];
- if (tensor_built[op_output_name]) {
+ if (tensor_map.count(op_output_name)) {
llvm::errs() << "ERROR: tensor " << op_output_name << " is already built\n";
return mlir::failure();
}
tensor_map[op_output_name] = output_values[i];
- tensor_built[op_output_name] = true;
- queue_ready_consumers(op_output_name);
}
+ // look for any more ready consumers
+ queue_ready_operators();
}
// Construct return values
std::vector<mlir::Value> return_operands;
for (const auto& output_name : ser_block->GetOutputs()) {
// Sanity check if terminator mlir::Value is built
- if (!tensor_built[output_name]) {
- llvm::errs() << "ERROR: terminator mlir::Value " << output_name << " is not built\n";
+ if (!tensor_map.count(output_name)) {
+ llvm::errs() << "ERROR: terminator mlir::Value " << output_name
+ << " is not built in block " << ser_block->GetName() << "\n";
return mlir::failure();
}
mlir::Value output_value = tensor_map.at(output_name);
@@ -1267,8 +1406,6 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
for (auto& ser_block : ser_region->GetBlocks()) {
auto& block = region->emplaceBlock();
TosaMlirBlockBuilder block_builder(ser_block, this, &block);
- // Region Builders need access to block builders (?)
- block_builders.push_back(&block_builder);
if (block_builder.BuildAllOpsInBlock(return_values).failed()) {
return mlir::failure();
@@ -1291,12 +1428,6 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
return mlir::failure();
}
- if (tsh.GetRegions().size() != 1) {
- llvm::errs() << "Internal Error: TosaSerializationHandler's region list "
- "must contain exactly one region\n";
- return mlir::failure();
- }
-
TosaSerializationRegion* ser_main_region = tsh.GetRegions().front();
auto loc = func.getLoc();
@@ -1305,8 +1436,9 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
main_region->takeBody(*main_region); // empty old func body
auto op_builder = mlir::OpBuilder(func.getBody());
- TosaMlirRegionBuilder region_builder(ser_main_region, &tsh, main_region, &op_builder, loc);
- if (region_builder.BuildAllBlocksInRegion(main_returns).failed()) {
+ if (BuildRegion(ser_main_region, &tsh, main_region, &op_builder, loc,
+ main_returns)
+ .failed()) {
return mlir::failure();
}