aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-22 19:30:11 +0000
committerTai Ly <tai.ly@arm.com>2023-10-16 17:11:30 +0000
commit28026df342e035d71907c5e6cf1ee29f3974afea (patch)
treea29c6bb28380ead925405dbbeaa61f1d1bb74260 /src/TosaSerialize.cpp
parenta7d41ccfc92f0a00ee6b4e96f217eb8f27956b00 (diff)
downloadtosa_mlir_translator-28026df342e035d71907c5e6cf1ee29f3974afea.tar.gz
Implement deserialization of stateful ops
This patch implements deserialization of variable, variable.read and variable.write ops. The variable ops are deserialized before the function. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I029bdf087576e97cacab469386863e6d1baf855c
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp17
1 files changed, 12 insertions, 5 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 45a1b37..8aef8fd 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -210,7 +210,8 @@ private:
const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op);
TosaSerializationTensor *
BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type,
- const std::string &name);
+ const std::string &name,
+ const std::string &variable_mlir_name);
TosaSerializationTensor *
BuildTosaSerializationTensor(mlir::Value val, const std::string &name);
@@ -305,7 +306,6 @@ static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) {
std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
mlir::Operation *op) const {
- mlir::Attribute variable_op_name_attr = op->getAttr("name");
std::string variable_tensor_mlir_name =
op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
@@ -1766,8 +1766,13 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
.cast<mlir::TypeAttr>()
.getValue()
.cast<mlir::RankedTensorType>();
+
+ std::string variable_mlir_name =
+ op->getAttr("name").cast<mlir::StringAttr>().getValue().str();
+
ser_tensor = BuildTosaSerializationVariableTensor(
- tensor_type /* tensor_type */, pair.first /* flatbuffer name */);
+ tensor_type /* tensor_type */, pair.first /* flatbuffer name */,
+ variable_mlir_name);
if (!ser_tensor) {
llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n";
return mlir::failure();
@@ -1897,7 +1902,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator(
TosaSerializationTensor *
TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor(
- mlir::RankedTensorType tensor_type, const std::string &name) {
+ mlir::RankedTensorType tensor_type, const std::string &name,
+ const std::string &variable_mlir_name) {
// If tensor already created before, use that tensor directly, create a new
// one otherwise
TosaSerializationTensor *ts = ser_block->GetTensorByName(name);
@@ -1912,7 +1918,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor(
ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(),
/* is_variable = */ true,
- /* is_unranked = */ false);
+ /* is_unranked = */ false,
+ variable_mlir_name);
return ts;
}