aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp22
1 files changed, 20 insertions, 2 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 153f16f..45e1f18 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -175,6 +175,22 @@ const std::string ResizeEnum2Str(const tosa::ResizeMode &mode) {
return "";
}
+// this is a counter part to Type2PoolAccumDType
+mlir::TypeAttr AccumDType2TypeAttr(mlir::OpBuilder *op_builder, DType dtype) {
+ // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+ if (dtype == DType_INT32) {
+ return mlir::TypeAttr::get(op_builder->getI32Type());
+ } else if (dtype == DType_FP32) {
+ return mlir::TypeAttr::get(op_builder->getF32Type());
+ } else if (dtype == DType_FP16) {
+ return mlir::TypeAttr::get(op_builder->getF16Type());
+ } else {
+ // unknown accum type
+ // for now, default to F32
+ return mlir::TypeAttr::get(op_builder->getF32Type());
+ }
+}
+
} // namespace
class TosaMlirRegionBuilder;
@@ -290,17 +306,19 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>(
mlir::DenseI64ArrayAttr stride =
BuildDenseI64ArrayAttr(op_builder, attr->stride());
mlir::DenseI64ArrayAttr pad = BuildDenseI64ArrayAttr(op_builder, attr->pad());
+ auto acc_attr = AccumDType2TypeAttr(op_builder, attr->accum_dtype());
+
int32_t input_zp = attr->input_zp();
int32_t output_zp = attr->output_zp();
mlir::Operation *mlir_op;
if (input_zp == 0 && output_zp == 0) {
mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
- loc, output_type, input_val, kernel, stride, pad);
+ loc, output_type, input_val, kernel, stride, pad, acc_attr);
} else {
auto quant = op_builder->getAttr<mlir::tosa::UnaryOpQuantizationAttr>(
input_zp, output_zp);
mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>(
- loc, output_type, input_val, kernel, stride, pad, quant);
+ loc, output_type, input_val, kernel, stride, pad, acc_attr, quant);
}
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});