diff options
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 45a1b37..8aef8fd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -210,7 +210,8 @@ private: const TosaSerializationOperatorBuilder &op_builder, mlir::Operation &op); TosaSerializationTensor * BuildTosaSerializationVariableTensor(mlir::RankedTensorType tensor_type, - const std::string &name); + const std::string &name, + const std::string &variable_mlir_name); TosaSerializationTensor * BuildTosaSerializationTensor(mlir::Value val, const std::string &name); @@ -305,7 +306,6 @@ static std::vector<T> getDenseI8ArrayAttr(mlir::Attribute attr) { std::string TosaSerializationOperatorBuilder::GetVariableTensorName( mlir::Operation *op) const { - mlir::Attribute variable_op_name_attr = op->getAttr("name"); std::string variable_tensor_mlir_name = op->getAttr("name").cast<mlir::StringAttr>().getValue().str(); @@ -1766,8 +1766,13 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( .cast<mlir::TypeAttr>() .getValue() .cast<mlir::RankedTensorType>(); + + std::string variable_mlir_name = + op->getAttr("name").cast<mlir::StringAttr>().getValue().str(); + ser_tensor = BuildTosaSerializationVariableTensor( - tensor_type /* tensor_type */, pair.first /* flatbuffer name */); + tensor_type /* tensor_type */, pair.first /* flatbuffer name */, + variable_mlir_name); if (!ser_tensor) { llvm::errs() << "ERROR: Failed to build TosaSerializationTensor\n"; return mlir::failure(); @@ -1897,7 +1902,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationOperator( TosaSerializationTensor * TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( - mlir::RankedTensorType tensor_type, const std::string &name) { + mlir::RankedTensorType tensor_type, const std::string &name, + const std::string &variable_mlir_name) { // If tensor already created before, use that tensor directly, create a new // one otherwise TosaSerializationTensor *ts = ser_block->GetTensorByName(name); @@ -1912,7 +1918,8 @@ TosaSerializationBlockBuilder::BuildTosaSerializationVariableTensor( ts = new TosaSerializationTensor(name, shape, type, std::vector<uint8_t>(), /* is_variable = */ true, - /* is_unranked = */ false); + /* is_unranked = */ false, + variable_mlir_name); return ts; } |