From ea49f62f7ab81750f19bef011683164fe9bd4080 Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Thu, 10 Aug 2023 23:00:22 -0700 Subject: Add DIM operator and its serialization/deserialization to TOSA MLIR translator Signed-off-by: Won Jeon Change-Id: Idf45f0a68f551039cc40a69d90aa5c53098bc238 --- include/operator.def | 1 + include/schema_operator.def | 1 + src/TosaDeserialize.cpp | 22 ++++++++++++++++++++++ src/TosaSerialize.cpp | 21 +++++++++++++++++++++ third_party/serialization_lib | 2 +- 5 files changed, 46 insertions(+), 1 deletion(-) diff --git a/include/operator.def b/include/operator.def index 10cbc5d..6198c0e 100644 --- a/include/operator.def +++ b/include/operator.def @@ -91,6 +91,7 @@ DEF_OPERATOR(ReduceSum) /* memory operation */ DEF_OPERATOR(Concat) DEF_OPERATOR(Pad) +DEF_OPERATOR(Dim) DEF_OPERATOR(Reshape) DEF_OPERATOR(Reverse) DEF_OPERATOR(Slice) diff --git a/include/schema_operator.def b/include/schema_operator.def index a11eeeb..52c7ae4 100644 --- a/include/schema_operator.def +++ b/include/schema_operator.def @@ -75,6 +75,7 @@ DEF_SCHEMA_OPERATOR(REDUCE_PRODUCT) DEF_SCHEMA_OPERATOR(REDUCE_SUM) DEF_SCHEMA_OPERATOR(CONCAT) DEF_SCHEMA_OPERATOR(PAD) +DEF_SCHEMA_OPERATOR(DIM) DEF_SCHEMA_OPERATOR(RESHAPE) DEF_SCHEMA_OPERATOR(REVERSE) DEF_SCHEMA_OPERATOR(SLICE) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 70a19a7..9e646c7 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -87,6 +87,9 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, case DType_UINT16: element_type = op_builder->getIntegerType(16, false); break; + case DType_SHAPE: + element_type = op_builder->getIntegerType(64); + break; default: llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()] << "\n"; @@ -859,6 +862,25 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { return std::vector({mlir_op->getResult(0)}); } +template <> +std::vector +TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_AxisAttribute); // double check attribute type + TosaAxisAttribute *attr = + static_cast(op->GetAttribute()); + auto axis = op_builder->getI32IntegerAttr(attr->axis()); + + mlir::Operation *mlir_op = + op_builder->create(loc, output_type, input_val, axis); + block->push_back(mlir_op); + return std::vector({mlir_op->getResult(0)}); +} + template <> std::vector TosaMlirOperatorBuilder::build( TosaSerializationOperator *op) const { diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 17a693e..741597b 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -86,6 +86,9 @@ static DType Type2DType(mlir::Type element_type) { return DType_INT32; } else if (element_type.isInteger(48)) { return DType_INT48; + } else if (element_type.isInteger(64)) { + // shape treated as integer with bitwidth 64 for now + return DType_SHAPE; } // boolean in MLIR treated as integer with bitwidth 1 else if (element_type.isInteger(1)) { @@ -974,6 +977,24 @@ TosaSerializationOperatorBuilder::build( return tyop; } +template <> +TosaSerializationOperator * +TosaSerializationOperatorBuilder::build( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t axis = op.getAttr("axis").dyn_cast().getInt(); + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute, + std::vector{input_name}, + std::vector{output_name}); + + return tyop; +} + template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build( diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 89963aa..1adc5d0 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 89963aa8fad822ab7a6e1ff92f6b7b4ee0b9350c +Subproject commit 1adc5d05d9fd21591790678a3f1cdaa4c5b347c4 -- cgit v1.2.1