diff options
-rw-r--r-- | include/DeserializationPasses.h | 9 | ||||
-rw-r--r-- | src/TosaDeserialize.cpp | 213 |
2 files changed, 199 insertions, 23 deletions
diff --git a/include/DeserializationPasses.h b/include/DeserializationPasses.h index 1bc195a..1a38814 100644 --- a/include/DeserializationPasses.h +++ b/include/DeserializationPasses.h @@ -19,6 +19,8 @@ #include <memory> #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { @@ -27,6 +29,13 @@ namespace tosa { std::unique_ptr<Pass> createTosaDeserializePass(); std::unique_ptr<Pass> createTosaDeserializeJSONPass(); +// deserializes a tosa file and return an mlir module +// if file_is_fbs is true, then treat file_name as a tosa flatbuffer file +// otherwise, treat file_name as a tosa json file +mlir::OwningOpRef<mlir::ModuleOp> +BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context, + bool file_is_fbs = true); + #define GEN_PASS_REGISTRATION #define GEN_PASS_CLASSES #include "include/DeserializationPasses.h.inc" diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 495d6f0..196c8f6 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -45,6 +45,13 @@ llvm::cl::opt<std::string> tosa_deserialize_schema( "tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"), llvm::cl::init(""), llvm::cl::value_desc("filename")); +const std::string kDefaultExportedName = "tosa_deserialized"; +const std::string kDefaultInputPrefix = "input_"; +const std::string kDefaultOutputPrefix = "output_"; +const std::string kDefaultFBSDescription = "Tosa FBS Converted"; +const std::string kDefaultJSONDescription = "Tosa JSON Converted"; +const std::string kMainFunctionName = "main"; + namespace { // construct tensor type from dtype and shape of TosaSerializationTensor @@ -1418,9 +1425,10 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion( return mlir::success(); } -mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, - mlir::MLIRContext& context, - tosa::TosaSerializationHandler& tsh) { +mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func, + mlir::MLIRContext &context, + tosa::TosaSerializationHandler &tsh, + std::vector<mlir::Value> &main_returns) { mlir::Region* main_region = func.getCallableRegion(); if (!main_region) { @@ -1431,7 +1439,6 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func, TosaSerializationRegion* ser_main_region = tsh.GetRegions().front(); auto loc = func.getLoc(); - std::vector<mlir::Value> main_returns; main_region->takeBody(*main_region); // empty old func body auto op_builder = mlir::OpBuilder(func.getBody()); @@ -1465,26 +1472,179 @@ mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler& tsh) { return mlir::success(); } +namespace { + +mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder, + bool is_input, int count) { + std::string names; + for (int i = 0; i < count; i++) { + std::string name = kDefaultExportedName + "_"; + name += (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix); + name += std::to_string(i) + ":0"; + if (i > 0) { + names += ","; + } + names += name; + } + return builder.getNamedAttr((is_input ? "inputs" : "outputs"), + builder.getStringAttr(names)); +} + +// erase function attrs and empty function region'd body +void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) { + function->setAttrs(mlir::DictionaryAttr::get(&context, {})); + mlir::Region *main_region = function.getCallableRegion(); + main_region->takeBody(*main_region); +} + +// replace attrs and body of @a to_function and its parent module +// by @a from_module and its "main" function +mlir::LogicalResult CloneIntoModuleAndFunction( + mlir::MLIRContext &context, mlir::func::FuncOp &to_function, + mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function, + mlir::ModuleOp &from_module) { + // copy all attrs from new_module to module + to_module->setAttrs(from_module->getAttrDictionary()); + // erase attrs and body of function + ResetFunction(to_function, context); + // clone new_func attrs and region into function + mlir::IRMapping mapping; + from_function.cloneInto(to_function, mapping); + return mlir::success(); +} + +} // namespace + namespace mlir { namespace tosa { +mlir::OwningOpRef<mlir::ModuleOp> +BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context, + bool file_is_fbs) { + TosaSerializationHandler tsh; + if (file_is_fbs) { + if (tsh.LoadFileTosaFlatbuffer(file_name)) { + llvm::errs() << "Fail to load TOSA file " << file_name << "\n"; + return nullptr; + } + } else { + // must load tosa schema before loading json file + if (loadTosaSchema(tsh).failed()) { + return nullptr; + } + if (tsh.LoadFileJson(file_name)) { + llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n"; + return nullptr; + } + } + + // create new module + auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0); + auto module = mlir::ModuleOp::create(base_loc); + + // set module attributes + const auto &tosa_version = tsh.GetVersion().to_string(); + std::string tosa_description = + file_is_fbs ? kDefaultFBSDescription : kDefaultJSONDescription; + auto builder = mlir::Builder(context); + module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version)); + module->setAttr("tosa.description", builder.getStringAttr(tosa_description)); + module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context)); + + // construct function with input and return types + llvm::SmallVector<mlir::Type, 2> ret_types; + llvm::SmallVector<mlir::Type, 4> input_types; + auto func_type = builder.getFunctionType(input_types, ret_types); + auto func_loc = + mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc); + auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type, + /* attrs= */ {}); + func.addEntryBlock(); + + // deserialize tosa fbs into function + std::vector<mlir::Value> main_returns; + if (buildTosaMlir(func, *context, tsh, main_returns).failed()) { + llvm::errs() << "Failed to deserialize flatbuffer " + << tosa_deserialize_filename << "\n"; + return nullptr; + } + auto main_args = func.getCallableRegion()->getArguments(); + // extract function input types + for (auto arg : main_args) { + input_types.push_back(arg.getType()); + } + // extract function return types + for (auto ret : main_returns) { + ret_types.push_back(ret.getType()); + } + // set function type with full input and return types + func_type = builder.getFunctionType(input_types, ret_types); + func.setType(func_type); + + // set function attributes + llvm::SmallVector<mlir::NamedAttribute, 2> attributes; + if (!input_types.empty()) { + attributes.push_back(DefaultEntryFuncitonAttr( + builder, /* is_input = */ true, /* count = */ input_types.size())); + for (int i = 0; i < input_types.size(); i++) { + std::string input_i = kDefaultInputPrefix + std::to_string(i); + func.setArgAttr(i, "tf_saved_model.index_path", + mlir::ArrayAttr::get( + context, {mlir::StringAttr::get(context, input_i)})); + } + } + if (!ret_types.empty()) { + attributes.push_back(DefaultEntryFuncitonAttr( + builder, /* is_input = */ false, /* count = */ ret_types.size())); + for (int i = 0; i < ret_types.size(); i++) { + std::string output_i = kDefaultOutputPrefix + std::to_string(i); + func.setResultAttr( + i, "tf_saved_model.index_path", + mlir::ArrayAttr::get(context, + {mlir::StringAttr::get(context, output_i)})); + } + } + func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); + func->setAttr( + "tf_saved_model.exported_names", + mlir::ArrayAttr::get( + context, {mlir::StringAttr::get(context, kDefaultExportedName)})); + + // add func to module + module.push_back(std::move(func)); + return mlir::OwningOpRef<mlir::ModuleOp>(module); +} + namespace { class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> { public: void runOnOperation() final { - TosaSerializationHandler tsh; - if (tsh.LoadFileTosaFlatbuffer(tosa_deserialize_filename.c_str())) { - llvm::errs() << "Fail to load TOSA file " << tosa_deserialize_filename << "\n"; - return signalPassFailure(); - } - auto function = getOperation(); auto& context = getContext(); - if (buildTosaMlir(function, context, tsh).failed()) { - llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n"; + auto new_module_ref = BuildMlirFromTosaFile( + tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ true); + if (!new_module_ref) { + return signalPassFailure(); + } + + mlir::ModuleOp new_module = *new_module_ref; + auto builder = mlir::Builder(&context); + auto module = function->getParentOfType<mlir::ModuleOp>(); + auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>( + builder.getStringAttr(kMainFunctionName)); + if (!new_function) { + llvm::errs() << "Failed to find main function in deserialized module\n"; + return signalPassFailure(); + } + if (CloneIntoModuleAndFunction(context, + /* to_function = */ function, + /* to_module = */ module, + /* from_function = */ new_function, + /* from_module = */ new_module) + .failed()) { return signalPassFailure(); } } @@ -1494,23 +1654,30 @@ class TosaDeserializeJSON : public TosaDeserializationJSONPassBase<TosaDeserializeJSON> { public: void runOnOperation() final { - TosaSerializationHandler tsh; + auto function = getOperation(); + auto &context = getContext(); - // must load tosa schema before loading json file - if (loadTosaSchema(tsh).failed()) { + auto new_module_ref = BuildMlirFromTosaFile( + tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ false); + if (!new_module_ref) { return signalPassFailure(); } - if (tsh.LoadFileJson(tosa_deserialize_filename.c_str())) { - llvm::errs() << "Fail to load TOSA JSON file " << tosa_deserialize_filename << "\n"; + mlir::ModuleOp new_module = *new_module_ref; + auto builder = mlir::Builder(&context); + auto module = function->getParentOfType<mlir::ModuleOp>(); + auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>( + builder.getStringAttr(kMainFunctionName)); + if (!new_function) { + llvm::errs() << "Failed to find main function in deserialized module\n"; return signalPassFailure(); } - - auto function = getOperation(); - auto& context = getContext(); - - if (buildTosaMlir(function, context, tsh).failed()) { - llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n"; + if (CloneIntoModuleAndFunction(context, + /* to_function = */ function, + /* to_module = */ module, + /* from_function = */ new_function, + /* from_module = */ new_module) + .failed()) { return signalPassFailure(); } } |