diff options
author | Won Jeon <won.jeon@arm.com> | 2023-08-10 23:00:22 -0700 |
---|---|---|
committer | Won Jeon <won.jeon@arm.com> | 2023-08-16 15:44:40 -0700 |
commit | ea49f62f7ab81750f19bef011683164fe9bd4080 (patch) | |
tree | d70004aa7024cfc6c3be6f92cce70d404e0a8b95 /src | |
parent | 5857c5ddaf849909cbd2ceb445b7ec9cb5c9ae43 (diff) | |
download | tosa_mlir_translator-ea49f62f7ab81750f19bef011683164fe9bd4080.tar.gz |
Add DIM operator and its serialization/deserialization to TOSA MLIR translator
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: Idf45f0a68f551039cc40a69d90aa5c53098bc238
Diffstat (limited to 'src')
-rw-r--r-- | src/TosaDeserialize.cpp | 22 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 21 |
2 files changed, 43 insertions, 0 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 70a19a7..9e646c7 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -87,6 +87,9 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, case DType_UINT16: element_type = op_builder->getIntegerType(16, false); break; + case DType_SHAPE: + element_type = op_builder->getIntegerType(64); + break; default: llvm::errs() << "ERROR: unknown type " << EnumNamesDType()[ts->GetDtype()] << "\n"; @@ -860,6 +863,25 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { } template <> +std::vector<mlir::Value> +TosaMlirOperatorBuilder::build<Op_DIM>(TosaSerializationOperator *op) const { + mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); + mlir::RankedTensorType output_type = + tensor_type_map->at(op->GetOutputTensorNames()[0]); + + assert(op->GetAttributeType() == + Attribute_AxisAttribute); // double check attribute type + TosaAxisAttribute *attr = + static_cast<TosaAxisAttribute *>(op->GetAttribute()); + auto axis = op_builder->getI32IntegerAttr(attr->axis()); + + mlir::Operation *mlir_op = + op_builder->create<mlir::tosa::DimOp>(loc, output_type, input_val, axis); + block->push_back(mlir_op); + return std::vector<mlir::Value>({mlir_op->getResult(0)}); +} + +template <> std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE>( TosaSerializationOperator *op) const { mlir::Value input_val = tensor_map->at(op->GetInputTensorNames()[0]); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 17a693e..741597b 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -86,6 +86,9 @@ static DType Type2DType(mlir::Type element_type) { return DType_INT32; } else if (element_type.isInteger(48)) { return DType_INT48; + } else if (element_type.isInteger(64)) { + // shape treated as integer with bitwidth 64 for now + return DType_SHAPE; } // boolean in MLIR treated as integer with bitwidth 1 else if (element_type.isInteger(1)) { @@ -976,6 +979,24 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( template <> TosaSerializationOperator * +TosaSerializationOperatorBuilder::build<mlir::tosa::DimOp>( + mlir::Operation &op) const { + std::string input_name = GetTensorName(op.getOperand(0)); + std::string output_name = GetTensorName(op.getResult(0)); + + int32_t axis = op.getAttr("axis").dyn_cast<mlir::IntegerAttr>().getInt(); + TosaAxisAttribute attribute(axis); + + TosaSerializationOperator *tyop = + new TosaSerializationOperator(Op_DIM, Attribute_AxisAttribute, &attribute, + std::vector<std::string>{input_name}, + std::vector<std::string>{output_name}); + + return tyop; +} + +template <> +TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeOp>( mlir::Operation &op) const { std::string input_name = GetTensorName(op.getOperand(0)); |