From 5eddcd35c1776784baeeb39e92bad81da826e065 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Wed, 13 Mar 2024 19:19:53 +0000 Subject: [tosa_mlir_translator] Add acc_type to conv ops Add serializing/deserializing acc_type to/from ConvAttribute Signed-off-by: Tai Ly Change-Id: I20780056f467952eb8baf6f5e80d242df6f0b33f --- src/TosaSerialize.cpp | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 55a11fd..875303e 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -105,14 +105,16 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -static DType Type2PoolAccDType(mlir::Type element_type) { - // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; +static DType Type2AccDType(mlir::Type element_type) { + // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; if (element_type.isF32()) { return DType_FP32; } else if (element_type.isF16()) { return DType_FP16; } else if (element_type.isInteger(32)) { return DType_INT32; + } else if (element_type.isInteger(48)) { + return DType_INT48; } return DType_UNKNOWN; } @@ -465,7 +467,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, // AvgPool has acc_type, MaxPool does not if (op.hasAttr("acc_type")) { auto acc_type = op.getAttr("acc_type").cast().getValue(); - acc_dtype = Type2PoolAccDType(acc_type); + acc_dtype = Type2AccDType(acc_type); } std::string input_name = GetTensorName(op.getOperand(0)); @@ -735,8 +737,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, - local_bound); + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -778,8 +783,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, - local_bound); + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -821,8 +829,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, - local_bound); + local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -867,8 +878,11 @@ TosaSerializationOperatorBuilder::build( ? op.getAttr("local_bound").dyn_cast().getValue() : false; + auto acc_type = op.getAttr("acc_type").cast().getValue(); + auto acc_dtype = Type2AccDType(acc_type); + TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, - weight_zp, local_bound); + weight_zp, local_bound, acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, -- cgit v1.2.1