diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index b12d22a..6fa691e 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -356,19 +356,21 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { return ""; } -// this is a counter part to Type2PoolAccDType -mlir::TypeAttr AccDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) { - // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; +// this is a counter part to Type2AccDType +mlir::Type AccDType2Type(mlir::OpBuilder *op_builder, DType dtype) { + // def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; if (dtype == DType_INT32) { - return mlir::TypeAttr::get(op_builder->getI32Type()); + return op_builder->getI32Type(); + } else if (dtype == DType_INT48) { + return op_builder->getIntegerType(48); } else if (dtype == DType_FP32) { - return mlir::TypeAttr::get(op_builder->getF32Type()); + return op_builder->getF32Type(); } else if (dtype == DType_FP16) { - return mlir::TypeAttr::get(op_builder->getF16Type()); + return op_builder->getF16Type(); } else { // unknown acc type // for now, default to F32 - return mlir::TypeAttr::get(op_builder->getF32Type()); + return op_builder->getF32Type(); } } @@ -504,7 +506,8 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>( mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); - auto acc_attr = AccDType2TypeAttr(op_builder, attr->acc_type()); + auto acc_attr = + mlir::TypeAttr::get(AccDType2Type(op_builder, attr->acc_type())); int32_t input_zp = attr->input_zp(); int32_t output_zp = attr->output_zp(); @@ -776,6 +779,7 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); + auto acc_type = AccDType2Type(op_builder, attr->acc_type()); // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; @@ -787,7 +791,7 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); mlir_op = op_builder->create<MLIR_OP>( loc, output_type, input0_val, input1_val, input2_val, pad, stride, - dilation, input_zp_attr, weight_zp_attr, local_bound); + dilation, acc_type, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); @@ -827,6 +831,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>( auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); + auto acc_type = AccDType2Type(op_builder, attr->acc_type()); // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; @@ -839,7 +844,7 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>( mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>( loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, - output_shape, input_zp_attr, weight_zp_attr, local_bound); + output_shape, acc_type, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); |