aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp33
1 files changed, 23 insertions, 10 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 0557520..3c95c75 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -23,8 +23,8 @@
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tosa_serialization_handler.h"
#include <functional>
-#include <unordered_map>
#include <map>
+#include <unordered_map>
// The namespace might be confusing here. We have mlir::tosa:: defined in MLIR
// and tosa:: defined in serialization library
@@ -205,6 +205,7 @@ private:
TosaSerializationHandler *tsh;
mlir::Region *region;
std::unordered_map<mlir::Value, std::string> tensor_map;
+ std::unordered_map<mlir::Value, std::string> input_tensor_map;
};
std::string
@@ -1473,6 +1474,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion(
"TosaInput_" + std::to_string(input_tensor_index++);
block->GetInputs().push_back(block_input_name);
tensor_map[args] = block_input_name;
+ input_tensor_map[args] = block_input_name;
}
// Build tensor_map
@@ -1492,8 +1494,10 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion(
for (auto val : op.getOperands()) {
// Workaround to skip mlir::tensor::CastOp before return
mlir::Operation *val_defining_op = val.getDefiningOp();
- if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op))
- val = val_defining_op->getOperand(0);
+ if (val_defining_op) {
+ if (llvm::isa<mlir::tensor::CastOp>(*val_defining_op))
+ val = val_defining_op->getOperand(0);
+ }
// Sanity check. This mlir::Value should be built in map since graph
// is DAG
@@ -1501,11 +1505,20 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion(
llvm::errs() << "ERROR: Can't find built mlir::Value key.\n";
return mlir::failure();
}
- std::string output_name =
- "TosaOutput_" + std::to_string(output_tensor_index++);
- tensor_map[val] = output_name;
- block->GetOutputs().push_back(output_name);
- return_values.push_back(val);
+
+ // If returned value is block input, short-circuit the tensor name
+ // Otherwise, build a new output name and override the origin tensor
+ // name
+ if (input_tensor_map.find(val) != input_tensor_map.end()) {
+ block->GetOutputs().push_back(input_tensor_map[val]);
+ return_values.push_back(val);
+ } else {
+ std::string output_name =
+ "TosaOutput_" + std::to_string(output_tensor_index++);
+ tensor_map[val] = output_name;
+ block->GetOutputs().push_back(output_name);
+ return_values.push_back(val);
+ }
}
}
}
@@ -1513,8 +1526,8 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInRegion(
// Build tensor
// The tensor_map is sorted by hashed mlir::Value types.
- // For serialization, sort tensors alphabetically by name for a deterministic
- // and human-friendly ordering.
+ // For serialization, sort tensors alphabetically by name for a
+ // deterministic and human-friendly ordering.
std::map<std::string, mlir::Value> tensor_name_sort;
for (auto pair : tensor_map)
tensor_name_sort[pair.second] = pair.first;