From 28026df342e035d71907c5e6cf1ee29f3974afea Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 22 Sep 2023 19:30:11 +0000 Subject: Implement deserialization of stateful ops This patch implements deserialization of variable, variable.read and variable.write ops. The variable ops are deserialized before the function. Signed-off-by: Tai Ly Change-Id: I029bdf087576e97cacab469386863e6d1baf855c --- src/TosaSerialize.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 45a1b37..8aef8fd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -210,7 +210,8 @@ private: const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); TosaSerializationTensor * BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type, - const std::string &name); + const std::string &name, + const std::string &variable_mlir_name); TosaSerializationTensor * BuildTosaSerializationTensor(mlir::Value val, const std::string &name); @@ -305,7 +306,6 @@ static std::vector getDenseI8ArrayAttr(mlir::Attribute attr) { std::string TosaSerializationOperatorBuilder::GetVariableTensorName( mlir::Operation *op) const { - mlir::Attribute variable_op_name_attr = op->getAttr("name"); std::string variable_tensor_mlir_name = op->getAttr("name").cast().getValue().str(); @@ -1766,8 +1766,13 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( .cast() .getValue() .cast(); + + std::string variable_mlir_name = + op->getAttr("name").cast().getValue().str(); + ser_tensor = BuildTosaSerializationVariableTensor( - tensor_type /* tensor_type */, pair.first /* flatbuffer name */); + tensor_type /* tensor_type */, pair.first /* flatbuffer name */, + variable_mlir_name); if (!ser_tensor) { llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; return mlir::failure(); @@ -1897,7 +1902,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( - mlir::RankedTensorType tensor_type, const std::string &name) { + mlir::RankedTensorType tensor_type, const std::string &name, + const std::string &variable_mlir_name) { // If tensor already created before, use that tensor directly, create a new // one otherwise TosaSerializationTensor *ts = ser_block->GetTensorByName(name); @@ -1912,7 +1918,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( ts = new TosaSerializationTensor(name, shape, type, std::vector(), /* is_variable = */ true, - /* is_unranked = */ false); + /* is_unranked = */ false, + variable_mlir_name); return ts; } -- cgit v1.2.1