From 28026df342e035d71907c5e6cf1ee29f3974afea Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 22 Sep 2023 19:30:11 +0000 Subject: Implement deserialization of stateful ops This patch implements deserialization of variable, variable.read and variable.write ops. The variable ops are deserialized before the function. Signed-off-by: Tai Ly Change-Id: I029bdf087576e97cacab469386863e6d1baf855c --- src/TosaDeserialize.cpp | 346 ++++++++++++++++++++++++++++++------------ src/TosaSerialize.cpp | 17 ++- third_party/serialization_lib | 2 +- 3 files changed, 260 insertions(+), 105 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 79f0c78..031c57f 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -54,6 +54,51 @@ const std::string kMainFunctionName = "main"; namespace { +// a global map from flatbuffer variable names to serialized tensors +std::unordered_map variable_tensor_map; + +void RegisterVariableTensor(TosaSerializationTensor *ts) { + assert(ts->GetVariable()); + // insert variable tensor ts only if not already present + variable_tensor_map.insert({ts->GetName(), ts}); +} + +bool IsVariableTensor(const std::string flatbuffer_tensor_name) { + return variable_tensor_map.count(flatbuffer_tensor_name); +} + +// return the variable name corresponding to flatbuffer_tensor_name +const std::string GetVariableTensorName(TosaSerializationTensor *ts) { + assert(ts->GetVariable()); + const auto name = ts->GetVariableName(); + if (name == "") { + // for legacy flatbuffers which may not have variable_name fields + return ts->GetName(); + } + return name; +} + +// return the variable name corresponding to flatbuffer_tensor_name +const std::string +GetVariableTensorName(const std::string flatbuffer_tensor_name) { + if (!IsVariableTensor(flatbuffer_tensor_name)) { + llvm::errs() << "ERROR: Variable tensor " << flatbuffer_tensor_name + << " is not found in variable_tensor_map"; + return ""; + } + return GetVariableTensorName(variable_tensor_map[flatbuffer_tensor_name]); +} + +bool IsVariableReadOp(TosaSerializationOperator *op) { + return (op->GetOp() == tosa::Op::Op_IDENTITY) && + IsVariableTensor(op->GetInputTensorNames()[0]); +} + +bool IsVariableWriteOp(TosaSerializationOperator *op) { + return (op->GetOp() == tosa::Op::Op_IDENTITY) && + IsVariableTensor(op->GetOutputTensorNames()[0]); +} + // construct tensor type from dtype and shape of TosaSerializationTensor mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, TosaSerializationTensor *ts, @@ -108,6 +153,107 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, return mlir::success(); } +mlir::DenseElementsAttr +ConstructConstAttr(const mlir::RankedTensorType &output_type, + TosaSerializationTensor *ts, const std::string &op_name) { + const auto &data = ts->GetData(); + auto &shape = ts->GetShape(); + // compute output data size + uint32_t out_size = 1; + for (const auto dim : shape) { + out_size *= dim; + } + mlir::DenseElementsAttr value_attr; + switch (ts->GetDtype()) { + case DType_FP32: { + std::vector float_data; + TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data)); + break; + } + case DType_INT4: { + std::vector int4_data; + TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data)); + break; + } + case DType_INT8: { + std::vector int8_data; + TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); + break; + } + case DType_INT16: { + std::vector int16_data; + TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data)); + break; + } + case DType_INT32: { + std::vector int32_data; + TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data)); + break; + } + case DType_INT48: { + std::vector int48_data; + TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data); + std::vector apint_data; + for (const auto v : int48_data) { + mlir::APInt apint_value(48, static_cast(v), + /* isSigned = */ false); + apint_data.push_back(apint_value); + } + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data)); + break; + } + case DType_BOOL: { + std::vector bool_data; + TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data); + llvm::SmallVector bool_values(bool_data.begin(), bool_data.end()); + value_attr = mlir::DenseElementsAttr::get(output_type, bool_values); + break; + } + default: { + llvm::errs() << "ERROR: " << op_name + << " contains unsupported element type\n"; + return nullptr; + } + } + + return value_attr; +} + +mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) { + if (variable_tensor_map.empty()) { + return mlir::success(); + } + auto loc = module.getLoc(); + auto op_builder = mlir::OpBuilder(module.getBodyRegion()); + for (auto [flatbuffer_name, ts] : variable_tensor_map) { + auto name = GetVariableTensorName(ts); + mlir::RankedTensorType type; + if (BuildTensorType(&op_builder, ts, type).failed()) { + return mlir::failure(); + } + + mlir::Attribute value_attr = nullptr; + if (!ts->GetData().empty()) { + value_attr = ConstructConstAttr(type, ts, name); + } + op_builder.create(loc, llvm::StringRef(name), type, + value_attr); + } + + return mlir::success(); +} + template mlir::DenseElementsAttr BuildDenseI8ElementsAttr(mlir::OpBuilder *op_builder, const std::vector &values) { @@ -221,6 +367,13 @@ public: template std::vector build(TosaSerializationOperator *op) const; + std::vector BuildVariableOp(TosaSerializationOperator *op) const; + + std::vector + BuildVariableReadOp(TosaSerializationOperator *op) const; + + void BuildVariableWriteOp(TosaSerializationOperator *op) const; + std::string get_string(TosaSerializationOperator *op) const { std::string op_string; op_string += "operator opcode="; @@ -334,6 +487,37 @@ std::vector TosaMlirOperatorBuilder::build( return std::vector({mlir_op->getResult(0)}); } +std::vector TosaMlirOperatorBuilder::BuildVariableReadOp( + TosaSerializationOperator *op) const { + auto input_tensor_name = op->GetInputTensorNames()[0]; + auto output_tensor_name = op->GetOutputTensorNames()[0]; + + assert(IsVariableTensor(input_tensor_name)); + + auto variable_name = GetVariableTensorName(input_tensor_name); + mlir::RankedTensorType output_type = tensor_type_map->at(output_tensor_name); + assert(op->GetAttributeType() == + Attribute_NONE); // double check that there is no attribute + mlir::Operation *mlir_op = op_builder->create( + loc, output_type, llvm::StringRef(variable_name)); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + +void TosaMlirOperatorBuilder::BuildVariableWriteOp( + TosaSerializationOperator *op) const { + auto input_tensor_name = op->GetInputTensorNames()[0]; + auto output_tensor_name = op->GetOutputTensorNames()[0]; + + assert(IsVariableTensor(output_tensor_name)); + auto variable_name = GetVariableTensorName(output_tensor_name); + mlir::Value input_val = tensor_map->at(input_tensor_name); + + mlir::Operation *mlir_op = op_builder->create( + loc, llvm::StringRef(variable_name), input_val); + block->push_back(mlir_op); +} + template std::vector TosaMlirOperatorBuilder::BuildEwiseUnaryOp( TosaSerializationOperator *op) const { @@ -460,73 +644,8 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { const auto &output_name = op->GetOutputTensorNames()[0]; mlir::RankedTensorType output_type = tensor_type_map->at(output_name); TosaSerializationTensor *ts = ser_block->GetTensorByName(output_name); - const auto &data = ts->GetData(); - auto &shape = ts->GetShape(); - // compute output data size - uint32_t out_size = 1; - for (const auto dim : shape) { - out_size *= dim; - } - mlir::DenseElementsAttr value_attr; - switch (ts->GetDtype()) { - case DType_FP32: { - std::vector float_data; - TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data)); - break; - } - case DType_INT4: { - std::vector int4_data; - TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data)); - break; - } - case DType_INT8: { - std::vector int8_data; - TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data)); - break; - } - case DType_INT16: { - std::vector int16_data; - TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data)); - break; - } - case DType_INT32: { - std::vector int32_data; - TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data); - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data)); - break; - } - case DType_INT48: { - std::vector int48_data; - TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data); - std::vector apint_data; - for (const auto v : int48_data) { - mlir::APInt apint_value(48, static_cast(v), - /* isSigned = */ false); - apint_data.push_back(apint_value); - } - value_attr = - mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data)); - break; - } - case DType_BOOL: { - std::vector bool_data; - TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data); - llvm::SmallVector bool_values(bool_data.begin(), bool_data.end()); - value_attr = mlir::DenseElementsAttr::get(output_type, bool_values); - break; - } - default: - llvm::errs() << "ERROR: " << get_string(op) - << " contains unsupported element type\n"; + auto value_attr = ConstructConstAttr(output_type, ts, get_string(op)); + if (!value_attr) { return {}; } mlir::Operation *mlir_op = @@ -1420,6 +1539,9 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( this, &tensor_map, &tensor_type_map); for (auto ts : ser_block->GetTensors()) { + if (ts->GetVariable()) { + RegisterVariableTensor(ts); + } mlir::RankedTensorType type; if (BuildTensorType(op_builder, ts, type).failed()) { return mlir::failure(); @@ -1432,27 +1554,6 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( } } - // Update operator_queue with operators whose inputs are all built - auto queue_ready_operators = [&]() { - // note: it is important to queue in order of original operators - // because an operator input may be defined in upper value scopes - for (auto consumer_op : ser_block->GetOperators()) { - if (operator_built.count(consumer_op)) { - continue; - } - bool all_inputs_ready = true; - for (const auto &input_name : consumer_op->GetInputTensorNames()) { - if (!tensor_map.count(input_name)) { - all_inputs_ready = false; - break; - } - } - if (all_inputs_ready) { - operator_queue.push(consumer_op); - } - } - }; - // Initialize tensor_map/operator_queue based on block input arguments for (const std::string &block_input_name : ser_block->GetInputs()) { mlir::Type type = tensor_type_map[block_input_name]; @@ -1470,11 +1571,7 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( tensor_map[block_input_name] = input_value; } - queue_ready_operators(); - - while (!operator_queue.empty()) { - TosaSerializationOperator *op = operator_queue.front(); - operator_queue.pop(); + for (auto op : ser_block->GetOperators()) { // skip if operator has been built if (operator_built.count(op)) { @@ -1485,7 +1582,10 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( operator_built.insert(op); std::vector output_values; - if (false) { + if (IsVariableReadOp(op)) { + output_values = tosa_op_builder.BuildVariableReadOp(op); + } else if (IsVariableWriteOp(op)) { + tosa_op_builder.BuildVariableWriteOp(op); } #define DEF_SCHEMA_OPERATOR(SCHEMA_OP_NAME) \ else if (op->GetOp() == Op_##SCHEMA_OP_NAME) { \ @@ -1499,6 +1599,12 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( return mlir::failure(); } + if (IsVariableWriteOp(op)) { + // the sanity checking below does not apply for variable write op because + // it has no output tensors whereas the original identity op has + continue; + } + // Sanity check if number of built mlir::Value is expected if (op->GetOutputTensorNames().size() != output_values.size()) { llvm::errs() << "ERROR: number of built mlir::Value is not matching " @@ -1516,8 +1622,6 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( } tensor_map[op_output_name] = output_values[i]; } - // look for any more ready consumers - queue_ready_operators(); } // Construct return values @@ -1533,8 +1637,15 @@ mlir::LogicalResult TosaMlirBlockBuilder::BuildAllOpsInBlock( return_operands.push_back(output_value); return_values.push_back(output_value); } - auto terminator_op = - op_builder->create(loc, return_operands); + mlir::Operation *terminator_op; + auto parent_op = block->getParentOp(); + if (mlir::isa(parent_op)) { + terminator_op = + op_builder->create(loc, return_operands); + } else { + terminator_op = + op_builder->create(loc, return_operands); + } block->push_back(terminator_op); // need topological sorting? @@ -1624,7 +1735,20 @@ mlir::NamedAttribute DefaultEntryFuncitonAttr(mlir::Builder &builder, builder.getStringAttr(names)); } -// erase function attrs and empty function region'd body +// erase all ops in block except for FuncOp +void ClearNonFuncOps(mlir::Block *block) { + std::vector to_delete; + for (auto &op : block->getOperations()) { + if (!mlir::isa(op)) { + to_delete.push_back(&op); + } + } + for (mlir::Operation *op : to_delete) { + op->erase(); + } +} + +// erase function attrs and empty function region's body void ResetFunction(mlir::func::FuncOp &function, mlir::MLIRContext &context) { function->setAttrs(mlir::DictionaryAttr::get(&context, {})); mlir::Region *main_region = function.getCallableRegion(); @@ -1637,6 +1761,9 @@ mlir::LogicalResult CloneIntoModuleAndFunction( mlir::MLIRContext &context, mlir::func::FuncOp &to_function, mlir::ModuleOp &to_module, mlir::func::FuncOp &from_function, mlir::ModuleOp &from_module) { + auto from_block = from_function.getOperation()->getBlock(); + auto to_block = to_function.getOperation()->getBlock(); + ClearNonFuncOps(to_block); // copy all attrs from new_module to module to_module->setAttrs(from_module->getAttrDictionary()); // erase attrs and body of function @@ -1644,6 +1771,22 @@ mlir::LogicalResult CloneIntoModuleAndFunction( // clone new_func attrs and region into function mlir::IRMapping mapping; from_function.cloneInto(to_function, mapping); + + // copy variable ops in from_block to to_block + // collect variable ops in from_block in reverse order + std::vector variable_ops; + for (mlir::Operation &op : *from_block) { + if (mlir::isa(op)) { + variable_ops.push_back(&op); + } + } + auto cloneOptions = + mlir::Operation::CloneOptions::all().cloneRegions(false).cloneOperands( + false); + for (auto iter = variable_ops.rbegin(); iter != variable_ops.rend(); iter++) { + auto op = *iter; + to_block->push_front(op->clone(mapping, cloneOptions)); + } return mlir::success(); } @@ -1745,6 +1888,11 @@ BuildMlirFromTosaFile(const char *file_name, mlir::MLIRContext *context, mlir::ArrayAttr::get( context, {mlir::StringAttr::get(context, kDefaultExportedName)})); + // deserialize variable ops in the new module just before adding func op + if (ConstructVariableOps(module).failed()) { + return nullptr; + } + // add func to module module.push_back(std::move(func)); return mlir::OwningOpRef(module); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 45a1b37..8aef8fd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -210,7 +210,8 @@ private: const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); TosaSerializationTensor * BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type, - const std::string &name); + const std::string &name, + const std::string &variable_mlir_name); TosaSerializationTensor * BuildTosaSerializationTensor(mlir::Value val, const std::string &name); @@ -305,7 +306,6 @@ static std::vector getDenseI8ArrayAttr(mlir::Attribute attr) { std::string TosaSerializationOperatorBuilder::GetVariableTensorName( mlir::Operation *op) const { - mlir::Attribute variable_op_name_attr = op->getAttr("name"); std::string variable_tensor_mlir_name = op->getAttr("name").cast().getValue().str(); @@ -1766,8 +1766,13 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( .cast() .getValue() .cast(); + + std::string variable_mlir_name = + op->getAttr("name").cast().getValue().str(); + ser_tensor = BuildTosaSerializationVariableTensor( - tensor_type /* tensor_type */, pair.first /* flatbuffer name */); + tensor_type /* tensor_type */, pair.first /* flatbuffer name */, + variable_mlir_name); if (!ser_tensor) { llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; return mlir::failure(); @@ -1897,7 +1902,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( - mlir::RankedTensorType tensor_type, const std::string &name) { + mlir::RankedTensorType tensor_type, const std::string &name, + const std::string &variable_mlir_name) { // If tensor already created before, use that tensor directly, create a new // one otherwise TosaSerializationTensor *ts = ser_block->GetTensorByName(name); @@ -1912,7 +1918,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( ts = new TosaSerializationTensor(name, shape, type, std::vector(), /* is_variable = */ true, - /* is_unranked = */ false); + /* is_unranked = */ false, + variable_mlir_name); return ts; } diff --git a/third_party/serialization_lib b/third_party/serialization_lib index c0a6030..5917fc7 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit c0a60300951a59c33d2afaea0f6ca0889cabf340 +Subproject commit 5917fc7a9392da8fd1e8c68b2d00b89709a31584 -- cgit v1.2.1