aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 17a693e..741597b 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -86,6 +86,9 @@ static DType Type2DType(mlir::Type element_type) {
return DType_INT32;
} else if (element_type.isInteger(48)) {
return DType_INT48;
+ } else if (element_type.isInteger(64)) {
+ // shape treated as integer with bitwidth 64 for now
+ return DType_SHAPE;
}
// boolean in MLIR treated as integer with bitwidth 1
else if (element_type.isInteger(1)) {
@@ -976,6 +979,24 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
template <>
TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::DimOp>(
+ mlir::Operation &op) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+ TosaAxisAttribute attribute(axis);
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));