aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp213
1 files changed, 190 insertions, 23 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 495d6f0..196c8f6 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -45,6 +45,13 @@ llvm::cl::opt<std::string> tosa_deserialize_schema(
"tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"),
llvm::cl::init(""), llvm::cl::value_desc("filename"));
+const std::string kDefaultExportedName = "tosa_deserialized";
+const std::string kDefaultInputPrefix = "input_";
+const std::string kDefaultOutputPrefix = "output_";
+const std::string kDefaultFBSDescription = "Tosa FBS Converted";
+const std::string kDefaultJSONDescription = "Tosa JSON Converted";
+const std::string kMainFunctionName = "main";
+
namespace {
// construct tensor type from dtype and shape of TosaSerializationTensor
@@ -1418,9 +1425,10 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
return mlir::success();
}
-mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
- mlir::MLIRContext& context,
- tosa::TosaSerializationHandler& tsh) {
+mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func,
+ mlir::MLIRContext &context,
+ tosa::TosaSerializationHandler &tsh,
+ std::vector<mlir::Value> &main_returns) {
mlir::Region* main_region = func.getCallableRegion();
if (!main_region) {
@@ -1431,7 +1439,6 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
TosaSerializationRegion* ser_main_region = tsh.GetRegions().front();
auto loc = func.getLoc();
- std::vector<mlir::Value> main_returns;
main_region->takeBody(*main_region); // empty old func body
auto op_builder = mlir::OpBuilder(func.getBody());
@@ -1465,26 +1472,179 @@ mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler& tsh) {
return mlir::success();
}
+namespace {
+
+mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder,
+ bool is_input, int count) {
+ std::string names;
+ for (int i = 0; i < count; i++) {
+ std::string name = kDefaultExportedName + "_";
+ name += (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix);
+ name += std::to_string(i) + ":0";
+ if (i > 0) {
+ names += ",";
+ }
+ names += name;
+ }
+ return builder.getNamedAttr((is_input ? "inputs" : "outputs"),
+ builder.getStringAttr(names));
+}
+
+// erase function attrs and empty function region'd body
+void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) {
+ function->setAttrs(mlir::DictionaryAttr::get(&context, {}));
+ mlir::Region *main_region = function.getCallableRegion();
+ main_region->takeBody(*main_region);
+}
+
+// replace attrs and body of @a to_function and its parent module
+// by @a from_module and its "main" function
+mlir::LogicalResult CloneIntoModuleAndFunction(
+ mlir::MLIRContext &context, mlir::func::FuncOp &to_function,
+ mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function,
+ mlir::ModuleOp &from_module) {
+ // copy all attrs from new_module to module
+ to_module->setAttrs(from_module->getAttrDictionary());
+ // erase attrs and body of function
+ ResetFunction(to_function, context);
+ // clone new_func attrs and region into function
+ mlir::IRMapping mapping;
+ from_function.cloneInto(to_function, mapping);
+ return mlir::success();
+}
+
+} // namespace
+
namespace mlir {
namespace tosa {
+mlir::OwningOpRef<mlir::ModuleOp>
+BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context,
+ bool file_is_fbs) {
+ TosaSerializationHandler tsh;
+ if (file_is_fbs) {
+ if (tsh.LoadFileTosaFlatbuffer(file_name)) {
+ llvm::errs() << "Fail to load TOSA file " << file_name << "\n";
+ return nullptr;
+ }
+ } else {
+ // must load tosa schema before loading json file
+ if (loadTosaSchema(tsh).failed()) {
+ return nullptr;
+ }
+ if (tsh.LoadFileJson(file_name)) {
+ llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n";
+ return nullptr;
+ }
+ }
+
+ // create new module
+ auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0);
+ auto module = mlir::ModuleOp::create(base_loc);
+
+ // set module attributes
+ const auto &tosa_version = tsh.GetVersion().to_string();
+ std::string tosa_description =
+ file_is_fbs ? kDefaultFBSDescription : kDefaultJSONDescription;
+ auto builder = mlir::Builder(context);
+ module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version));
+ module->setAttr("tosa.description", builder.getStringAttr(tosa_description));
+ module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context));
+
+ // construct function with input and return types
+ llvm::SmallVector<mlir::Type, 2> ret_types;
+ llvm::SmallVector<mlir::Type, 4> input_types;
+ auto func_type = builder.getFunctionType(input_types, ret_types);
+ auto func_loc =
+ mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc);
+ auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type,
+ /* attrs= */ {});
+ func.addEntryBlock();
+
+ // deserialize tosa fbs into function
+ std::vector<mlir::Value> main_returns;
+ if (buildTosaMlir(func, *context, tsh, main_returns).failed()) {
+ llvm::errs() << "Failed to deserialize flatbuffer "
+ << tosa_deserialize_filename << "\n";
+ return nullptr;
+ }
+ auto main_args = func.getCallableRegion()->getArguments();
+ // extract function input types
+ for (auto arg : main_args) {
+ input_types.push_back(arg.getType());
+ }
+ // extract function return types
+ for (auto ret : main_returns) {
+ ret_types.push_back(ret.getType());
+ }
+ // set function type with full input and return types
+ func_type = builder.getFunctionType(input_types, ret_types);
+ func.setType(func_type);
+
+ // set function attributes
+ llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
+ if (!input_types.empty()) {
+ attributes.push_back(DefaultEntryFuncitonAttr(
+ builder, /* is_input = */ true, /* count = */ input_types.size()));
+ for (int i = 0; i < input_types.size(); i++) {
+ std::string input_i = kDefaultInputPrefix + std::to_string(i);
+ func.setArgAttr(i, "tf_saved_model.index_path",
+ mlir::ArrayAttr::get(
+ context, {mlir::StringAttr::get(context, input_i)}));
+ }
+ }
+ if (!ret_types.empty()) {
+ attributes.push_back(DefaultEntryFuncitonAttr(
+ builder, /* is_input = */ false, /* count = */ ret_types.size()));
+ for (int i = 0; i < ret_types.size(); i++) {
+ std::string output_i = kDefaultOutputPrefix + std::to_string(i);
+ func.setResultAttr(
+ i, "tf_saved_model.index_path",
+ mlir::ArrayAttr::get(context,
+ {mlir::StringAttr::get(context, output_i)}));
+ }
+ }
+ func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
+ func->setAttr(
+ "tf_saved_model.exported_names",
+ mlir::ArrayAttr::get(
+ context, {mlir::StringAttr::get(context, kDefaultExportedName)}));
+
+ // add func to module
+ module.push_back(std::move(func));
+ return mlir::OwningOpRef<mlir::ModuleOp>(module);
+}
+
namespace {
class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> {
public:
void runOnOperation() final {
- TosaSerializationHandler tsh;
- if (tsh.LoadFileTosaFlatbuffer(tosa_deserialize_filename.c_str())) {
- llvm::errs() << "Fail to load TOSA file " << tosa_deserialize_filename << "\n";
- return signalPassFailure();
- }
-
auto function = getOperation();
auto& context = getContext();
- if (buildTosaMlir(function, context, tsh).failed()) {
- llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
+ auto new_module_ref = BuildMlirFromTosaFile(
+ tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ true);
+ if (!new_module_ref) {
+ return signalPassFailure();
+ }
+
+ mlir::ModuleOp new_module = *new_module_ref;
+ auto builder = mlir::Builder(&context);
+ auto module = function->getParentOfType<mlir::ModuleOp>();
+ auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>(
+ builder.getStringAttr(kMainFunctionName));
+ if (!new_function) {
+ llvm::errs() << "Failed to find main function in deserialized module\n";
+ return signalPassFailure();
+ }
+ if (CloneIntoModuleAndFunction(context,
+ /* to_function = */ function,
+ /* to_module = */ module,
+ /* from_function = */ new_function,
+ /* from_module = */ new_module)
+ .failed()) {
return signalPassFailure();
}
}
@@ -1494,23 +1654,30 @@ class TosaDeserializeJSON
: public TosaDeserializationJSONPassBase<TosaDeserializeJSON> {
public:
void runOnOperation() final {
- TosaSerializationHandler tsh;
+ auto function = getOperation();
+ auto &context = getContext();
- // must load tosa schema before loading json file
- if (loadTosaSchema(tsh).failed()) {
+ auto new_module_ref = BuildMlirFromTosaFile(
+ tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ false);
+ if (!new_module_ref) {
return signalPassFailure();
}
- if (tsh.LoadFileJson(tosa_deserialize_filename.c_str())) {
- llvm::errs() << "Fail to load TOSA JSON file " << tosa_deserialize_filename << "\n";
+ mlir::ModuleOp new_module = *new_module_ref;
+ auto builder = mlir::Builder(&context);
+ auto module = function->getParentOfType<mlir::ModuleOp>();
+ auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>(
+ builder.getStringAttr(kMainFunctionName));
+ if (!new_function) {
+ llvm::errs() << "Failed to find main function in deserialized module\n";
return signalPassFailure();
}
-
- auto function = getOperation();
- auto& context = getContext();
-
- if (buildTosaMlir(function, context, tsh).failed()) {
- llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
+ if (CloneIntoModuleAndFunction(context,
+ /* to_function = */ function,
+ /* to_module = */ module,
+ /* from_function = */ new_function,
+ /* from_module = */ new_module)
+ .failed()) {
return signalPassFailure();
}
}