From f5645effe818c2b0cb0124597bc761850616ba76 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 13 Feb 2024 19:38:13 +0000 Subject: [tosa_mlir_translator] Refactor QuantizationAttr changes to serialization/deserialization due to removal of quantization attr in TOSA dialect Signed-off-by: Tai Ly Change-Id: I0903d2c1c62bc50822e6c08bb869ec135c986ff3 --- src/TosaDeserialize.cpp | 55 +++++++++++----------- src/TosaSerialize.cpp | 120 ++++++++++++++++++++++++++---------------------- 2 files changed, 94 insertions(+), 81 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 5704b04..eb60173 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -513,10 +513,11 @@ std::vector TosaMlirOperatorBuilder::build( mlir_op = op_builder->create( loc, output_type, input_val, kernel, stride, pad, acc_attr); } else { - auto quant = op_builder->getAttr( - input_zp, output_zp); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); mlir_op = op_builder->create( - loc, output_type, input_val, kernel, stride, pad, acc_attr, quant); + loc, output_type, input_val, kernel, stride, pad, acc_attr, + input_zp_attr, output_zp_attr); } block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); @@ -774,18 +775,17 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); - // quantizationattr is required for quantized type, and not allowed for float - // type + // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa()) { assert(input_zp == 0 && weight_zp == 0); } - auto quant = op_builder->getAttr( - input_zp, weight_zp); - mlir_op = op_builder->create(loc, output_type, input0_val, - input1_val, input2_val, pad, stride, - dilation, quant, local_bound); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); + mlir_op = op_builder->create( + loc, output_type, input0_val, input1_val, input2_val, pad, stride, + dilation, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); @@ -826,17 +826,18 @@ std::vector TosaMlirOperatorBuilder::build( auto weight_zp = attr->weight_zp(); bool local_bound = attr->local_bound(); - // quantizationattr is required for quantized type, and not allowed for float - // type + // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa()) { assert(input_zp == 0 && weight_zp == 0); } - auto quant = op_builder->getAttr( - input_zp, weight_zp); + + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); + mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, - output_shape, quant, local_bound); + output_shape, input_zp_attr, weight_zp_attr, local_bound); block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); @@ -858,18 +859,18 @@ std::vector TosaMlirOperatorBuilder::build( auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); - // quantizationattr is required for quantized type, and not allowed for float - // type + // input_zp/weight_zp is not allowed for float type mlir::Operation *mlir_op; if (output_type.getElementType().isa()) { assert(input_zp == 0 && weight_zp == 0); mlir_op = op_builder->create( loc, output_type, input0_val, input1_val, input2_val); } else { - auto quant = op_builder->getAttr( - input_zp, weight_zp); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto weight_zp_attr = op_builder->getI32IntegerAttr(weight_zp); mlir_op = op_builder->create( - loc, output_type, input0_val, input1_val, input2_val, quant); + loc, output_type, input0_val, input1_val, input2_val, input_zp_attr, + weight_zp_attr); } block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); @@ -895,10 +896,10 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir_op = op_builder->create(loc, output_type, input0_val, input1_val); } else { - auto quant = - op_builder->getAttr(A_zp, B_zp); + auto a_zp_attr = op_builder->getI32IntegerAttr(A_zp); + auto b_zp_attr = op_builder->getI32IntegerAttr(B_zp); mlir_op = op_builder->create( - loc, output_type, input0_val, input1_val, quant); + loc, output_type, input0_val, input1_val, a_zp_attr, b_zp_attr); } block->push_back(mlir_op); return std::vector({mlir_op->getResult(0)}); @@ -1010,10 +1011,10 @@ TosaMlirOperatorBuilder::build(TosaSerializationOperator *op) const { mlir_op = op_builder->create(loc, output_type, input_val); } else { - auto quant = op_builder->getAttr( - input_zp, output_zp); - mlir_op = op_builder->create(loc, output_type, - input_val, quant); + auto input_zp_attr = op_builder->getI32IntegerAttr(input_zp); + auto output_zp_attr = op_builder->getI32IntegerAttr(output_zp); + mlir_op = op_builder->create( + loc, output_type, input_val, input_zp_attr, output_zp_attr); } block->push_back(mlir_op); diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 05c7812..0f5056d 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -471,14 +471,16 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType( - "quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t output_zp = quant_info ? quant_info.getOutputZp() : 0; + int32_t input_zp = + op.hasAttr("input_zp") + ? input_zp = op.getAttr("input_zp").cast().getInt() + : 0; + int32_t output_zp = + op.hasAttr("output_zp") + ? output_zp = + op.getAttr("output_zp").cast().getInt() + : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast(); TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, accum_dtype); @@ -717,13 +719,14 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast(); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast().getInt() + : 0; bool local_bound = op.hasAttr("local_bound") @@ -759,13 +762,14 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast(); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast().getInt() + : 0; bool local_bound = op.hasAttr("local_bound") @@ -801,13 +805,14 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast(); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast().getInt() + : 0; bool local_bound = op.hasAttr("local_bound") @@ -843,11 +848,15 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType("quantization_info"); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast().getInt() + : 0; - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); @@ -876,14 +885,15 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast().getInt() + : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast(); TosaFullyConnectedAttribute attribute(input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( @@ -902,13 +912,12 @@ TosaSerializationOperatorBuilder::build( std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType( - "quantization_info"); - - int32_t A_zp = quant_info ? quant_info.getAZp() : 0; - int32_t B_zp = quant_info ? quant_info.getBZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast(); + int32_t A_zp = op.hasAttr("a_zp") + ? op.getAttr("a_zp").cast().getInt() + : 0; + int32_t B_zp = op.hasAttr("b_zp") + ? op.getAttr("b_zp").cast().getInt() + : 0; TosaMatMulAttribute attribute(A_zp, B_zp); @@ -1035,13 +1044,16 @@ TosaSerializationOperatorBuilder::build( std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType( - "quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t output_zp = quant_info ? quant_info.getOutputZp() : 0; + int32_t input1_zp = + op.hasAttr("input1_zp") + ? op.getAttr("input1_zp").cast().getInt() + : 0; + int32_t output_zp = + op.hasAttr("output_zp") + ? op.getAttr("output_zp").cast().getInt() + : 0; - TosaNegateAttribute attribute(input_zp, output_zp); + TosaNegateAttribute attribute(input1_zp, output_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_NEGATE, Attribute_NegateAttribute, &attribute, @@ -1075,9 +1087,9 @@ TosaSerializationOperatorBuilder::build( std::string output_name = GetTensorName(op.getResult(0)); auto pad_op = llvm::cast(op); - auto quant_info = pad_op.getQuantizationInfoAttr(); + auto input_zp_attr = pad_op.getInputZpAttr(); // pad_const includes the zero point if the tensor uses a zero point. - int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0; + int32_t pad_const_int = input_zp_attr ? input_zp_attr.getInt() : 0; float pad_const_fp = 0.f; if (auto tensor = pad_op.getPadConst()) { -- cgit v1.2.1