From 3765888e5b2ca8c826edad0d9c90261573af4fce Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 4 Jan 2022 16:06:07 -0800 Subject: Short-circuit output tensor name with input tensor name if they're the same mlir::Value Signed-off-by: Kevin Cheng Change-Id: I578802841fbcacb67560b07b26170a9efa166d52 --- src/TosaSerialize.cpp | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 0557520..3c95c75 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -23,8 +23,8 @@ #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tosa_serialization_handler.h" #include -#include #include +#include // The namespace might be confusing here. We have mlir::tosa:: defined in MLIR // and tosa:: defined in serialization library @@ -205,6 +205,7 @@ private: TosaSerializationHandler *tsh; mlir::Region *region; std::unordered_map tensor_map; + std::unordered_map input_tensor_map; }; std::string @@ -1473,6 +1474,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( "TosaInput_" + std::to_string(input_tensor_index++); block->GetInputs().push_back(block_input_name); tensor_map[args] = block_input_name; + input_tensor_map[args] = block_input_name; } // Build tensor_map @@ -1492,8 +1494,10 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( for (auto val : op.getOperands()) { // Workaround to skip mlir::tensor::CastOp before return mlir::Operation *val_defining_op = val.getDefiningOp(); - if (llvm::isa(*val_defining_op)) - val = val_defining_op->getOperand(0); + if (val_defining_op) { + if (llvm::isa(*val_defining_op)) + val = val_defining_op->getOperand(0); + } // Sanity check. This mlir::Value should be built in map since graph // is DAG @@ -1501,11 +1505,20 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( llvm::errs() << "ERROR: Can't find built mlir::Value key.\n"; return mlir::failure(); } - std::string output_name = - "TosaOutput_" + std::to_string(output_tensor_index++); - tensor_map[val] = output_name; - block->GetOutputs().push_back(output_name); - return_values.push_back(val); + + // If returned value is block input, short-circuit the tensor name + // Otherwise, build a new output name and override the origin tensor + // name + if (input_tensor_map.find(val) != input_tensor_map.end()) { + block->GetOutputs().push_back(input_tensor_map[val]); + return_values.push_back(val); + } else { + std::string output_name = + "TosaOutput_" + std::to_string(output_tensor_index++); + tensor_map[val] = output_name; + block->GetOutputs().push_back(output_name); + return_values.push_back(val); + } } } } @@ -1513,8 +1526,8 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion( // Build tensor // The tensor_map is sorted by hashed mlir::Value types. - // For serialization, sort tensors alphabetically by name for a deterministic - // and human-friendly ordering. + // For serialization, sort tensors alphabetically by name for a + // deterministic and human-friendly ordering. std::map tensor_name_sort; for (auto pair : tensor_map) tensor_name_sort[pair.second] = pair.first; -- cgit v1.2.1