diff options
author | Tai Ly <tai.ly@arm.com> | 2023-09-22 19:30:11 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-10-16 17:11:30 +0000 |
commit | 28026df342e035d71907c5e6cf1ee29f3974afea (patch) | |
tree | a29c6bb28380ead925405dbbeaa61f1d1bb74260 /src/TosaSerialize.cpp | |
parent | a7d41ccfc92f0a00ee6b4e96f217eb8f27956b00 (diff) | |
download | tosa_mlir_translator-28026df342e035d71907c5e6cf1ee29f3974afea.tar.gz |
Implement deserialization of stateful ops
This patch implements deserialization of variable,
variable.read and variable.write ops.
The variable ops are deserialized before the function.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I029bdf087576e97cacab469386863e6d1baf855c
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 45a1b37..8aef8fd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -210,7 +210,8 @@ private: const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); TosaSerializationTensor * BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type, - const std::string &name); + const std::string &name, + const std::string &variable_mlir_name); TosaSerializationTensor * BuildTosaSerializationTensor(mlir::Value val, const std::string &name); @@ -305,7 +306,6 @@ static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) { std::string TosaSerializationOperatorBuilder::GetVariableTensorName( mlir::Operation *op) const { - mlir::Attribute variable_op_name_attr = op->getAttr("name"); std::string variable_tensor_mlir_name = op->getAttr("name").cast<mlir::StringAttr>().getValue().str(); @@ -1766,8 +1766,13 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( .cast<mlir::TypeAttr>() .getValue() .cast<mlir::RankedTensorType>(); + + std::string variable_mlir_name = + op->getAttr("name").cast<mlir::StringAttr>().getValue().str(); + ser_tensor = BuildTosaSerializationVariableTensor( - tensor_type /* tensor_type */, pair.first /* flatbuffer name */); + tensor_type /* tensor_type */, pair.first /* flatbuffer name */, + variable_mlir_name); if (!ser_tensor) { llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; return mlir::failure(); @@ -1897,7 +1902,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( - mlir::RankedTensorType tensor_type, const std::string &name) { + mlir::RankedTensorType tensor_type, const std::string &name, + const std::string &variable_mlir_name) { // If tensor already created before, use that tensor directly, create a new // one otherwise TosaSerializationTensor *ts = ser_block->GetTensorByName(name); @@ -1912,7 +1918,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(), /* is_variable = */ true, - /* is_unranked = */ false); + /* is_unranked = */ false, + variable_mlir_name); return ts; } |