diff options
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 9a6cf84..5b0d2bd 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -105,13 +105,13 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -static DType Type2PoolAccumDType(mlir::Type element_type) { - // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; +static DType Type2PoolAccDType(mlir::Type element_type) { + // def Tosa_AccType : AnyTypeOf<[I<32>, F16, F32]>; if (element_type.isF32()) { return DType_FP32; } else if (element_type.isF16()) { return DType_FP16; - } else if (element_type.isInteger(32) || element_type.isSignedInteger(32)) { + } else if (element_type.isInteger(32)) { return DType_INT32; } return DType_UNKNOWN; @@ -461,11 +461,11 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, auto kernel = getDenseI64ArrayAttr<int>(op.getAttr("kernel")); ASSERT_VECTOR_LENGTH(kernel, 2); - DType accum_dtype = DType_FP32; - // AvgPool has accum_dtype, MaxPool does not + DType acc_dtype = DType_FP32; + // AvgPool has acc_type, MaxPool does not if (op.hasAttr("acc_type")) { auto acc_type = op.getAttr("acc_type").cast<mlir::TypeAttr>().getValue(); - accum_dtype = Type2PoolAccumDType(acc_type); + acc_dtype = Type2PoolAccDType(acc_type); } std::string input_name = GetTensorName(op.getOperand(0)); @@ -482,7 +482,7 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, : 0; TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, - accum_dtype); + acc_dtype); TosaSerializationOperator *tyop = new TosaSerializationOperator(opcode, Attribute_PoolAttribute, &attribute, |