From 5eddcd35c1776784baeeb39e92bad81da826e065 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 13 Mar 2024 19:19:53 +0000 Subject: [tosa_mlir_translator] Add acc_type to conv ops Add serializing/deserializing acc_type to/from ConvAttribute Signed-off-by: Tai Ly Change-Id: I20780056f467952eb8baf6f5e80d242df6f0b33f --- src/TosaDeserialize.cpp | 25 +++++++++++++++---------- src/TosaSerialize.cpp | 28 +++++++++++++++++++++------- third_party/serialization_lib | 2 +- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index b12d22a..6fa691e 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -356,19 +356,21 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { return ""; } -// this is a counter part to Type2PoolAccDType -mlir::TypeAttr AccDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) { - // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; +// this is a counter part to Type2AccDType +mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) { + // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; if (dtype == DType_INT32) { - return mlir::TypeAttr::get(op_builder->getI32Type()); + return op_builder->getI32Type(); + } else if (dtype == DType_INT48) { + return op_builder->getIntegerType(48); } else if (dtype == DType_FP32) { - return mlir::TypeAttr::get(op_builder->getF32Type()); + return op_builder->getF32Type(); } else if (dtype == DType_FP16) { - return mlir::TypeAttr::get(op_builder->getF16Type()); + return op_builder->getF16Type(); } else { // unknown acc type // for now, default to F32 - return mlir::TypeAttr::get(op_builder->getF32Type()); + return op_builder->getF32Type(); } } @@ -504,7 +506,8 @@ std::vector TosaMlirOperatorBuilder::build( mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); - auto acc_attr = AccDType2TypeAttr(op_builder, attr->acc_type()); + auto acc_attr = + mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type())); int32_t input_zp = attr->input_zp(); int32_t output_zp = attr->output_zp(); @@ -776,6 +779,7 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); + auto acc_type = AccDType2Type(op_builder, attr->acc_type()); // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; @@ -787,7 +791,7 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val, pad, stride, - dilation, input_zp_attr, weight_zp_attr, local_bound); + dilation, acc_type, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); @@ -827,6 +831,7 @@ std::vector TosaMlirOperatorBuilder::build( auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); + auto acc_type = AccDType2Type(op_builder, attr->acc_type()); // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; @@ -839,7 +844,7 @@ std::vector TosaMlirOperatorBuilder::build( mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, - output_shape, input_zp_attr, weight_zp_attr, local_bound); + output_shape, acc_type, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 55a11fd..875303e 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -105,14 +105,16 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -static DType Type2PoolAccDType(mlir::Type element_type) { - // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; +static DType Type2AccDType(mlir::Type element_type) { + // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; if (element_type.isF32()) { return DType_FP32; } else if (element_type.isF16()) { return DType_FP16; } else if (element_type.isInteger(32)) { return DType_INT32; + } else if (element_type.isInteger(48)) { + return DType_INT48; } return DType_UNKNOWN; } @@ -465,7 +467,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, // AvgPool has acc_type, MaxPool does not if (op.hasAttr("acc_type")) { auto acc_type = op.getAttr("acc_type").cast().getValue(); - acc_dtype = Type2PoolAccDType(acc_type); + acc_dtype = Type2AccDType(acc_type); } std::string input_name = GetTensorName(op.getOperand(0)); @@ -735,8 +737,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, - local_bound); + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -778,8 +783,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, - local_bound); + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -821,8 +829,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, - local_bound); + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -867,8 +878,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, - weight_zp, local_bound); + weight_zp, local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 0b6d7c2..ad78daa 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 0b6d7c271af1e6593e6a2cf14b32acea765f4b64 +Subproject commit ad78daaf0fa1e41742cbed314459c3dbbb483c20 -- cgit v1.2.1