aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2023-08-10 23:00:22 -0700
committerWon Jeon <won.jeon@arm.com>2023-08-16 15:44:40 -0700
commitea49f62f7ab81750f19bef011683164fe9bd4080 (patch)
treed70004aa7024cfc6c3be6f92cce70d404e0a8b95
parent5857c5ddaf849909cbd2ceb445b7ec9cb5c9ae43 (diff)
downloadtosa_mlir_translator-ea49f62f7ab81750f19bef011683164fe9bd4080.tar.gz
Add DIM operator and its serialization/deserialization to TOSA MLIR translator
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: Idf45f0a68f551039cc40a69d90aa5c53098bc238
-rw-r--r--include/operator.def1
-rw-r--r--include/schema_operator.def1
-rw-r--r--src/TosaDeserialize.cpp22
-rw-r--r--src/TosaSerialize.cpp21
m---------third_party/serialization_lib0
5 files changed, 45 insertions, 0 deletions
diff --git a/include/operator.def b/include/operator.def
index 10cbc5d..6198c0e 100644
--- a/include/operator.def
+++ b/include/operator.def
@@ -91,6 +91,7 @@ DEF_OPERATOR(ReduceSum)
/* memory operation */
DEF_OPERATOR(Concat)
DEF_OPERATOR(Pad)
+DEF_OPERATOR(Dim)
DEF_OPERATOR(Reshape)
DEF_OPERATOR(Reverse)
DEF_OPERATOR(Slice)
diff --git a/include/schema_operator.def b/include/schema_operator.def
index a11eeeb..52c7ae4 100644
--- a/include/schema_operator.def
+++ b/include/schema_operator.def
@@ -75,6 +75,7 @@ DEF_SCHEMA_OPERATOR(REDUCE_PRODUCT)
DEF_SCHEMA_OPERATOR(REDUCE_SUM)
DEF_SCHEMA_OPERATOR(CONCAT)
DEF_SCHEMA_OPERATOR(PAD)
+DEF_SCHEMA_OPERATOR(DIM)
DEF_SCHEMA_OPERATOR(RESHAPE)
DEF_SCHEMA_OPERATOR(REVERSE)
DEF_SCHEMA_OPERATOR(SLICE)
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 70a19a7..9e646c7 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -87,6 +87,9 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
case DType_UINT16:
element_type = op_builder->getIntegerType(16, false);
break;
+ case DType_SHAPE:
+ element_type = op_builder->getIntegerType(64);
+ break;
default:
llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()]
<< "\n";
@@ -860,6 +863,25 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
}
template <>
+std::vector<mlir::Value>
+TosaMlirOperatorBuilder::build<Op_DIM>(TosaSerializationOperator *op) const {
+ mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
+ mlir::RankedTensorType output_type =
+ tensor_type_map->at(op->GetOutputTensorNames()[0]);
+
+ assert(op->GetAttributeType() ==
+ Attribute_AxisAttribute); // double check attribute type
+ TosaAxisAttribute *attr =
+ static_cast<TosaAxisAttribute *>(op->GetAttribute());
+ auto axis = op_builder->getI32IntegerAttr(attr->axis());
+
+ mlir::Operation *mlir_op =
+ op_builder->create<mlir::tosa::DimOp>(loc, output_type, input_val, axis);
+ block->push_back(mlir_op);
+ return std::vector<mlir::Value>({mlir_op->getResult(0)});
+}
+
+template <>
std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>(
TosaSerializationOperator *op) const {
mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]);
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 17a693e..741597b 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -86,6 +86,9 @@ static DType Type2DType(mlir::Type element_type) {
return DType_INT32;
} else if (element_type.isInteger(48)) {
return DType_INT48;
+ } else if (element_type.isInteger(64)) {
+ // shape treated as integer with bitwidth 64 for now
+ return DType_SHAPE;
}
// boolean in MLIR treated as integer with bitwidth 1
else if (element_type.isInteger(1)) {
@@ -976,6 +979,24 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
template <>
TosaSerializationOperator *
+TosaSerializationOperatorBuilder::build<mlir::tosa::DimOp>(
+ mlir::Operation &op) const {
+ std::string input_name = GetTensorName(op.getOperand(0));
+ std::string output_name = GetTensorName(op.getResult(0));
+
+ int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt();
+ TosaAxisAttribute attribute(axis);
+
+ TosaSerializationOperator *tyop =
+ new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute,
+ std::vector<std::string>{input_name},
+ std::vector<std::string>{output_name});
+
+ return tyop;
+}
+
+template <>
+TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>(
mlir::Operation &op) const {
std::string input_name = GetTensorName(op.getOperand(0));
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 89963aa8fad822ab7a6e1ff92f6b7b4ee0b9350
+Subproject 1adc5d05d9fd21591790678a3f1cdaa4c5b347c