diff options
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 0557520..3c95c75 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -23,8 +23,8 @@ #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tosa_serialization_handler.h" #include <functional> -#include <unordered_map> #include <map> +#include <unordered_map> // The namespace might be confusing here. We have mlir::tosa:: defined in MLIR // and tosa:: defined in serialization library @@ -205,6 +205,7 @@ private: TosaSerializationHandler *tsh; mlir::Region *region; std::unordered_map<mlir::Value, std::string> tensor_map; + std::unordered_map<mlir::Value, std::string> input_tensor_map; }; std::string @@ -1473,6 +1474,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( "TosaInput_" + std::to_string(input_tensor_index++); block->GetInputs().push_back(block_input_name); tensor_map[args] = block_input_name; + input_tensor_map[args] = block_input_name; } // Build tensor_map @@ -1492,8 +1494,10 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( for (auto val : op.getOperands()) { // Workaround to skip mlir::tensor::CastOp before return mlir::Operation *val_defining_op = val.getDefiningOp(); - if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op)) - val = val_defining_op->getOperand(0); + if (val_defining_op) { + if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op)) + val = val_defining_op->getOperand(0); + } // Sanity check. This mlir::Value should be built in map since graph // is DAG @@ -1501,11 +1505,20 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; return mlir::failure(); } - std::string output_name = - "TosaOutput_" + std::to_string(output_tensor_index++); - tensor_map[val] = output_name; - block->GetOutputs().push_back(output_name); - return_values.push_back(val); + + // If returned value is block input, short-circuit the tensor name + // Otherwise, build a new output name and override the origin tensor + // name + if (input_tensor_map.find(val) != input_tensor_map.end()) { + block->GetOutputs().push_back(input_tensor_map[val]); + return_values.push_back(val); + } else { + std::string output_name = + "TosaOutput_" + std::to_string(output_tensor_index++); + tensor_map[val] = output_name; + block->GetOutputs().push_back(output_name); + return_values.push_back(val); + } } } } @@ -1513,8 +1526,8 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( // Build tensor // The tensor_map is sorted by hashed mlir::Value types. - // For serialization, sort tensors alphabetically by name for a deterministic - // and human-friendly ordering. + // For serialization, sort tensors alphabetically by name for a + // deterministic and human-friendly ordering. std::map<std::string, mlir::Value> tensor_name_sort; for (auto pair : tensor_map) tensor_name_sort[pair.second] = pair.first; |