diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 55 |
1 files changed, 28 insertions, 27 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 5704b04..eb60173 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -513,10 +513,11 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_AVG_POOL2D>( mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( loc, output_type, input_val, kernel, stride, pad, acc_attr); } else { - auto quant = op_builder->getAttr<mlir::tosa::UnaryOpQuantizationAttr>( - input_zp, output_zp); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); mlir_op = op_builder->create<mlir::tosa::AvgPool2dOp>( - loc, output_type, input_val, kernel, stride, pad, acc_attr, quant); + loc, output_type, input_val, kernel, stride, pad, acc_attr, + input_zp_attr, output_zp_attr); } block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); @@ -774,18 +775,17 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); - // quantizationattr is required for quantized type, and not allowed for float - // type + // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa<mlir::FloatType>()) { assert(input_zp == 0 && weight_zp == 0); } - auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>( - input_zp, weight_zp); - mlir_op = op_builder->create<MLIR_OP>(loc, output_type, input0_val, - input1_val, input2_val, pad, stride, - dilation, quant, local_bound); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); + mlir_op = op_builder->create<MLIR_OP>( + loc, output_type, input0_val, input1_val, input2_val, pad, stride, + dilation, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); @@ -826,17 +826,18 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>( auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); - // quantizationattr is required for quantized type, and not allowed for float - // type + // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa<mlir::FloatType>()) { assert(input_zp == 0 && weight_zp == 0); } - auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>( - input_zp, weight_zp); + + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); + mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>( loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, - output_shape, quant, local_bound); + output_shape, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); @@ -858,18 +859,18 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_FULLY_CONNECTED>( auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); - // quantizationattr is required for quantized type, and not allowed for float - // type + // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa<mlir::FloatType>()) { assert(input_zp == 0 && weight_zp == 0); mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>( loc, output_type, input0_val, input1_val, input2_val); } else { - auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>( - input_zp, weight_zp); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); mlir_op = op_builder->create<mlir::tosa::FullyConnectedOp>( - loc, output_type, input0_val, input1_val, input2_val, quant); + loc, output_type, input0_val, input1_val, input2_val, input_zp_attr, + weight_zp_attr); } block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); @@ -895,10 +896,10 @@ TosaMlirOperatorBuilder::build<Op_MATMUL>(TosaSerializationOperator *op) const { mlir_op = op_builder->create<mlir::tosa::MatMulOp>(loc, output_type, input0_val, input1_val); } else { - auto quant = - op_builder->getAttr<mlir::tosa::MatMulOpQuantizationAttr>(A_zp, B_zp); + auto a_zp_attr = op_builder->getI32IntegerAttr(A_zp); + auto b_zp_attr = op_builder->getI32IntegerAttr(B_zp); mlir_op = op_builder->create<mlir::tosa::MatMulOp>( - loc, output_type, input0_val, input1_val, quant); + loc, output_type, input0_val, input1_val, a_zp_attr, b_zp_attr); } block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); @@ -1010,10 +1011,10 @@ TosaMlirOperatorBuilder::build<Op_NEGATE>(TosaSerializationOperator *op) const { mlir_op = op_builder->create<mlir::tosa::NegateOp>(loc, output_type, input_val); } else { - auto quant = op_builder->getAttr<mlir::tosa::UnaryOpQuantizationAttr>( - input_zp, output_zp); - mlir_op = op_builder->create<mlir::tosa::NegateOp>(loc, output_type, - input_val, quant); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); + mlir_op = op_builder->create<mlir::tosa::NegateOp>( + loc, output_type, input_val, input_zp_attr, output_zp_attr); } block->push_back(mlir_op); |