aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-05-05 18:04:36 +0000
committerTai Ly <tai.ly@arm.com>2023-05-06 21:49:39 +0000
commit04992e45ad982ce34edacea9f9275ca122e47a75 (patch)
tree6843f96a84d9e07f220117949bd9faff64bfd4c3
parent8ffce6d63372ea93f3e060571137bce11d4735d8 (diff)
downloadtosa_mlir_translator-04992e45ad982ce34edacea9f9275ca122e47a75.tar.gz
Add BuildMlirFromTosaFile API
This BuildMlirFromTosaFile API deserializes a tosa fbs or json file and returns a mlir module as OwningOpRef<ModuleOp> This also refactors the existing deserialization passes to use the new API, and then copy the deserialized module's main function into the existing function, and copy all attributes of the new function and new module into the existing function and module. This allows testing of the new API by running deserialization passes. Here is an example showing the attributes on the deserialized module and functions: module attributes {tf_saved_model.semantics, tosa.description = "Tosa FBS Converted", tosa.fbs_version = "0.70.0d"} { func.func @main(%arg0: tensor<1x256x256x3xui8> {tf_saved_model.index_path = ["input_0"]}) -> (tensor<1x1x17x3xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "tosa_de serialized_input_0:0", outputs = "tosa_deserialized_output_0:0"}, tf_saved_model.exported_names = ["tosa_deserialized"]} { Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ia6c0202ef43ce5d37788cd459ed7c3f8424dd619
-rw-r--r--include/DeserializationPasses.h9
-rw-r--r--src/TosaDeserialize.cpp213
2 files changed, 199 insertions, 23 deletions
diff --git a/include/DeserializationPasses.h b/include/DeserializationPasses.h
index 1bc195a..1a38814 100644
--- a/include/DeserializationPasses.h
+++ b/include/DeserializationPasses.h
@@ -19,6 +19,8 @@
#include <memory>
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
+#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
@@ -27,6 +29,13 @@ namespace tosa {
std::unique_ptr<Pass> createTosaDeserializePass();
std::unique_ptr<Pass> createTosaDeserializeJSONPass();
+// deserializes a tosa file and return an mlir module
+// if file_is_fbs is true, then treat file_name as a tosa flatbuffer file
+// otherwise, treat file_name as a tosa json file
+mlir::OwningOpRef<mlir::ModuleOp>
+BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context,
+ bool file_is_fbs = true);
+
#define GEN_PASS_REGISTRATION
#define GEN_PASS_CLASSES
#include "include/DeserializationPasses.h.inc"
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 495d6f0..196c8f6 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -45,6 +45,13 @@ llvm::cl::opt<std::string> tosa_deserialize_schema(
"tosa-deserialize-schema", llvm::cl::desc("<tosa flatbuffer schema file>"),
llvm::cl::init(""), llvm::cl::value_desc("filename"));
+const std::string kDefaultExportedName = "tosa_deserialized";
+const std::string kDefaultInputPrefix = "input_";
+const std::string kDefaultOutputPrefix = "output_";
+const std::string kDefaultFBSDescription = "Tosa FBS Converted";
+const std::string kDefaultJSONDescription = "Tosa JSON Converted";
+const std::string kMainFunctionName = "main";
+
namespace {
// construct tensor type from dtype and shape of TosaSerializationTensor
@@ -1418,9 +1425,10 @@ mlir::LogicalResult TosaMlirRegionBuilder::BuildAllBlocksInRegion(
return mlir::success();
}
-mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
- mlir::MLIRContext& context,
- tosa::TosaSerializationHandler& tsh) {
+mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp &func,
+ mlir::MLIRContext &context,
+ tosa::TosaSerializationHandler &tsh,
+ std::vector<mlir::Value> &main_returns) {
mlir::Region* main_region = func.getCallableRegion();
if (!main_region) {
@@ -1431,7 +1439,6 @@ mlir::LogicalResult buildTosaMlir(mlir::func::FuncOp& func,
TosaSerializationRegion* ser_main_region = tsh.GetRegions().front();
auto loc = func.getLoc();
- std::vector<mlir::Value> main_returns;
main_region->takeBody(*main_region); // empty old func body
auto op_builder = mlir::OpBuilder(func.getBody());
@@ -1465,26 +1472,179 @@ mlir::LogicalResult loadTosaSchema(tosa::TosaSerializationHandler& tsh) {
return mlir::success();
}
+namespace {
+
+mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder,
+ bool is_input, int count) {
+ std::string names;
+ for (int i = 0; i < count; i++) {
+ std::string name = kDefaultExportedName + "_";
+ name += (is_input ? kDefaultInputPrefix : kDefaultOutputPrefix);
+ name += std::to_string(i) + ":0";
+ if (i > 0) {
+ names += ",";
+ }
+ names += name;
+ }
+ return builder.getNamedAttr((is_input ? "inputs" : "outputs"),
+ builder.getStringAttr(names));
+}
+
+// erase function attrs and empty function region'd body
+void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) {
+ function->setAttrs(mlir::DictionaryAttr::get(&context, {}));
+ mlir::Region *main_region = function.getCallableRegion();
+ main_region->takeBody(*main_region);
+}
+
+// replace attrs and body of @a to_function and its parent module
+// by @a from_module and its "main" function
+mlir::LogicalResult CloneIntoModuleAndFunction(
+ mlir::MLIRContext &context, mlir::func::FuncOp &to_function,
+ mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function,
+ mlir::ModuleOp &from_module) {
+ // copy all attrs from new_module to module
+ to_module->setAttrs(from_module->getAttrDictionary());
+ // erase attrs and body of function
+ ResetFunction(to_function, context);
+ // clone new_func attrs and region into function
+ mlir::IRMapping mapping;
+ from_function.cloneInto(to_function, mapping);
+ return mlir::success();
+}
+
+} // namespace
+
namespace mlir {
namespace tosa {
+mlir::OwningOpRef<mlir::ModuleOp>
+BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context,
+ bool file_is_fbs) {
+ TosaSerializationHandler tsh;
+ if (file_is_fbs) {
+ if (tsh.LoadFileTosaFlatbuffer(file_name)) {
+ llvm::errs() << "Fail to load TOSA file " << file_name << "\n";
+ return nullptr;
+ }
+ } else {
+ // must load tosa schema before loading json file
+ if (loadTosaSchema(tsh).failed()) {
+ return nullptr;
+ }
+ if (tsh.LoadFileJson(file_name)) {
+ llvm::errs() << "Fail to load TOSA JSON file " << file_name << "\n";
+ return nullptr;
+ }
+ }
+
+ // create new module
+ auto base_loc = mlir::FileLineColLoc::get(context, file_name, 0, 0);
+ auto module = mlir::ModuleOp::create(base_loc);
+
+ // set module attributes
+ const auto &tosa_version = tsh.GetVersion().to_string();
+ std::string tosa_description =
+ file_is_fbs ? kDefaultFBSDescription : kDefaultJSONDescription;
+ auto builder = mlir::Builder(context);
+ module->setAttr("tosa.fbs_version", builder.getStringAttr(tosa_version));
+ module->setAttr("tosa.description", builder.getStringAttr(tosa_description));
+ module->setAttr("tf_saved_model.semantics", mlir::UnitAttr::get(context));
+
+ // construct function with input and return types
+ llvm::SmallVector<mlir::Type, 2> ret_types;
+ llvm::SmallVector<mlir::Type, 4> input_types;
+ auto func_type = builder.getFunctionType(input_types, ret_types);
+ auto func_loc =
+ mlir::NameLoc::get(builder.getStringAttr(kMainFunctionName), base_loc);
+ auto func = mlir::func::FuncOp::create(func_loc, kMainFunctionName, func_type,
+ /* attrs= */ {});
+ func.addEntryBlock();
+
+ // deserialize tosa fbs into function
+ std::vector<mlir::Value> main_returns;
+ if (buildTosaMlir(func, *context, tsh, main_returns).failed()) {
+ llvm::errs() << "Failed to deserialize flatbuffer "
+ << tosa_deserialize_filename << "\n";
+ return nullptr;
+ }
+ auto main_args = func.getCallableRegion()->getArguments();
+ // extract function input types
+ for (auto arg : main_args) {
+ input_types.push_back(arg.getType());
+ }
+ // extract function return types
+ for (auto ret : main_returns) {
+ ret_types.push_back(ret.getType());
+ }
+ // set function type with full input and return types
+ func_type = builder.getFunctionType(input_types, ret_types);
+ func.setType(func_type);
+
+ // set function attributes
+ llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
+ if (!input_types.empty()) {
+ attributes.push_back(DefaultEntryFuncitonAttr(
+ builder, /* is_input = */ true, /* count = */ input_types.size()));
+ for (int i = 0; i < input_types.size(); i++) {
+ std::string input_i = kDefaultInputPrefix + std::to_string(i);
+ func.setArgAttr(i, "tf_saved_model.index_path",
+ mlir::ArrayAttr::get(
+ context, {mlir::StringAttr::get(context, input_i)}));
+ }
+ }
+ if (!ret_types.empty()) {
+ attributes.push_back(DefaultEntryFuncitonAttr(
+ builder, /* is_input = */ false, /* count = */ ret_types.size()));
+ for (int i = 0; i < ret_types.size(); i++) {
+ std::string output_i = kDefaultOutputPrefix + std::to_string(i);
+ func.setResultAttr(
+ i, "tf_saved_model.index_path",
+ mlir::ArrayAttr::get(context,
+ {mlir::StringAttr::get(context, output_i)}));
+ }
+ }
+ func->setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
+ func->setAttr(
+ "tf_saved_model.exported_names",
+ mlir::ArrayAttr::get(
+ context, {mlir::StringAttr::get(context, kDefaultExportedName)}));
+
+ // add func to module
+ module.push_back(std::move(func));
+ return mlir::OwningOpRef<mlir::ModuleOp>(module);
+}
+
namespace {
class TosaDeserialize : public TosaDeserializationPassBase<TosaDeserialize> {
public:
void runOnOperation() final {
- TosaSerializationHandler tsh;
- if (tsh.LoadFileTosaFlatbuffer(tosa_deserialize_filename.c_str())) {
- llvm::errs() << "Fail to load TOSA file " << tosa_deserialize_filename << "\n";
- return signalPassFailure();
- }
-
auto function = getOperation();
auto& context = getContext();
- if (buildTosaMlir(function, context, tsh).failed()) {
- llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
+ auto new_module_ref = BuildMlirFromTosaFile(
+ tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ true);
+ if (!new_module_ref) {
+ return signalPassFailure();
+ }
+
+ mlir::ModuleOp new_module = *new_module_ref;
+ auto builder = mlir::Builder(&context);
+ auto module = function->getParentOfType<mlir::ModuleOp>();
+ auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>(
+ builder.getStringAttr(kMainFunctionName));
+ if (!new_function) {
+ llvm::errs() << "Failed to find main function in deserialized module\n";
+ return signalPassFailure();
+ }
+ if (CloneIntoModuleAndFunction(context,
+ /* to_function = */ function,
+ /* to_module = */ module,
+ /* from_function = */ new_function,
+ /* from_module = */ new_module)
+ .failed()) {
return signalPassFailure();
}
}
@@ -1494,23 +1654,30 @@ class TosaDeserializeJSON
: public TosaDeserializationJSONPassBase<TosaDeserializeJSON> {
public:
void runOnOperation() final {
- TosaSerializationHandler tsh;
+ auto function = getOperation();
+ auto &context = getContext();
- // must load tosa schema before loading json file
- if (loadTosaSchema(tsh).failed()) {
+ auto new_module_ref = BuildMlirFromTosaFile(
+ tosa_deserialize_filename.c_str(), &context, /* file_is_fbs = */ false);
+ if (!new_module_ref) {
return signalPassFailure();
}
- if (tsh.LoadFileJson(tosa_deserialize_filename.c_str())) {
- llvm::errs() << "Fail to load TOSA JSON file " << tosa_deserialize_filename << "\n";
+ mlir::ModuleOp new_module = *new_module_ref;
+ auto builder = mlir::Builder(&context);
+ auto module = function->getParentOfType<mlir::ModuleOp>();
+ auto new_function = new_module.lookupSymbol<mlir::func::FuncOp>(
+ builder.getStringAttr(kMainFunctionName));
+ if (!new_function) {
+ llvm::errs() << "Failed to find main function in deserialized module\n";
return signalPassFailure();
}
-
- auto function = getOperation();
- auto& context = getContext();
-
- if (buildTosaMlir(function, context, tsh).failed()) {
- llvm::errs() << "Failed to deserialize flatbuffer " << tosa_deserialize_filename << "\n";
+ if (CloneIntoModuleAndFunction(context,
+ /* to_function = */ function,
+ /* to_module = */ module,
+ /* from_function = */ new_function,
+ /* from_module = */ new_module)
+ .failed()) {
return signalPassFailure();
}
}