diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 153f16f..45e1f18 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -175,6 +175,22 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) { return ""; } +// this is a counter part to Type2PoolAccumDType +mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) { + // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; + if (dtype == DType_INT32) { + return mlir::TypeAttr::get(op_builder->getI32Type()); + } else if (dtype == DType_FP32) { + return mlir::TypeAttr::get(op_builder->getF32Type()); + } else if (dtype == DType_FP16) { + return mlir::TypeAttr::get(op_builder->getF16Type()); + } else { + // unknown accum type + // for now, default to F32 + return mlir::TypeAttr::get(op_builder->getF32Type()); + } +} + } // namespace class TosaMlirRegionBuilder; @@ -290,17 +306,19 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>( mlir::DenseI64ArrayAttr stride = BuildDenseI64ArrayAttr(op_builder, attr->stride()); mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad()); + auto acc_attr = AccumDType2TypeAttr(op_builder, attr->accum_dtype()); + int32_t input_zp = attr->input_zp(); int32_t output_zp = attr->output_zp(); mlir::Operation *mlir_op; if (input_zp == 0 && output_zp == 0) { mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( - loc, output_type, input_val, kernel, stride, pad); + loc, output_type, input_val, kernel, stride, pad, acc_attr); } else { auto quant = op_builder->getAttr<mlir::tosa::UnaryOpQuantizationAttr>( input_zp, output_zp); mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( - loc, output_type, input_val, kernel, stride, pad, quant); + loc, output_type, input_val, kernel, stride, pad, acc_attr, quant); } block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); |