diff options
author | Tai Ly <tai.ly@arm.com> | 2024-02-13 19:38:13 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-02-22 02:17:31 +0000 |
commit | f5645effe818c2b0cb0124597bc761850616ba76 (patch) | |
tree | 6cd4abdc04a80e5a2f8875e0d6c328a331fae11a /src/TosaSerialize.cpp | |
parent | 4bc57740f03704179d2611f5d41572612bc42e9a (diff) | |
download | tosa_mlir_translator-f5645effe818c2b0cb0124597bc761850616ba76.tar.gz |
[tosa_mlir_translator] Refactor QuantizationAttr
changes to serialization/deserialization due to removal of
quantization attr in TOSA dialect
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I0903d2c1c62bc50822e6c08bb869ec135c986ff3
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 120 |
1 files changed, 66 insertions, 54 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 05c7812..0f5056d 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -471,14 +471,16 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>( - "quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t output_zp = quant_info ? quant_info.getOutputZp() : 0; + int32_t input_zp = + op.hasAttr("input_zp") + ? input_zp = op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t output_zp = + op.hasAttr("output_zp") + ? output_zp = + op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast<mlir::RankedTensorType>(); TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp, accum_dtype); @@ -717,13 +719,14 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; bool local_bound = op.hasAttr("local_bound") @@ -759,13 +762,14 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; bool local_bound = op.hasAttr("local_bound") @@ -801,13 +805,14 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; bool local_bound = op.hasAttr("local_bound") @@ -843,11 +848,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); @@ -876,14 +885,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = - op.getAttrOfType<mlir::tosa::ConvOpQuantizationAttr>("quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; + int32_t input_zp = + op.hasAttr("input_zp") + ? op.getAttr("input_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t weight_zp = + op.hasAttr("weight_zp") + ? op.getAttr("weight_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast<mlir::RankedTensorType>(); TosaFullyConnectedAttribute attribute(input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( @@ -902,13 +912,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>( std::string input1_name = GetTensorName(op.getOperand(1)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType<mlir::tosa::MatMulOpQuantizationAttr>( - "quantization_info"); - - int32_t A_zp = quant_info ? quant_info.getAZp() : 0; - int32_t B_zp = quant_info ? quant_info.getBZp() : 0; - mlir::RankedTensorType tensor = - op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + int32_t A_zp = op.hasAttr("a_zp") + ? op.getAttr("a_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t B_zp = op.hasAttr("b_zp") + ? op.getAttr("b_zp").cast<mlir::IntegerAttr>().getInt() + : 0; TosaMatMulAttribute attribute(A_zp, B_zp); @@ -1035,13 +1044,16 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::NegateOp>( std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - auto quant_info = op.getAttrOfType<mlir::tosa::UnaryOpQuantizationAttr>( - "quantization_info"); - - int32_t input_zp = quant_info ? quant_info.getInputZp() : 0; - int32_t output_zp = quant_info ? quant_info.getOutputZp() : 0; + int32_t input1_zp = + op.hasAttr("input1_zp") + ? op.getAttr("input1_zp").cast<mlir::IntegerAttr>().getInt() + : 0; + int32_t output_zp = + op.hasAttr("output_zp") + ? op.getAttr("output_zp").cast<mlir::IntegerAttr>().getInt() + : 0; - TosaNegateAttribute attribute(input_zp, output_zp); + TosaNegateAttribute attribute(input1_zp, output_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_NEGATE, Attribute_NegateAttribute, &attribute, @@ -1075,9 +1087,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>( std::string output_name = GetTensorName(op.getResult(0)); auto pad_op = llvm::cast<mlir::tosa::PadOp>(op); - auto quant_info = pad_op.getQuantizationInfoAttr(); + auto input_zp_attr = pad_op.getInputZpAttr(); // pad_const includes the zero point if the tensor uses a zero point. - int32_t pad_const_int = quant_info ? quant_info.getInputZp() : 0; + int32_t pad_const_int = input_zp_attr ? input_zp_attr.getInt() : 0; float pad_const_fp = 0.f; if (auto tensor = pad_op.getPadConst()) { |