From 728fd31bc08b12255500779a3a33b9e05597e41f Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Thu, 9 Jun 2022 16:54:48 -0700 Subject: Remove quantization info This adapts the mlir-translator to the updated serialization library which does not have quantization attributes in the schema. Change-Id: I321845c735426b2325590e0241e67242c31064a5 Signed-off-by: Eric Kunze --- src/TosaSerialize.cpp | 266 +++++++++++------------------------------- third_party/serialization_lib | 2 +- 2 files changed, 71 insertions(+), 197 deletions(-) diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 60393ae..a6ea5ff 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -94,72 +94,6 @@ DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -int GetQuantizedParameter(mlir::Type type, std::vector &scale, - std::vector &zeropoint, - int32_t &quantized_dimension, int64_t &quant_min, - int64_t &quant_max) { - if (auto qtype = type.dyn_cast()) { - scale.push_back(qtype.getScale()); - zeropoint.push_back(qtype.getZeroPoint()); - quantized_dimension = 0; - - quant_min = qtype.getStorageTypeMin(); - quant_max = qtype.getStorageTypeMax(); - } else if (auto qtype = - type.dyn_cast()) { - scale.assign(qtype.getScales().begin(), qtype.getScales().end()); - zeropoint.assign(qtype.getZeroPoints().begin(), - qtype.getZeroPoints().end()); - quantized_dimension = qtype.getQuantizedDimension(); - - quant_min = qtype.getStorageTypeMin(); - quant_max = qtype.getStorageTypeMax(); - } else { - return 1; - } - - return 0; -} - -TosaQuantInfoBase * -GetUnaryQuantInfo(mlir::tosa::UnaryOpQuantizationAttr quant_info) { - int32_t input_zp = quant_info.input_zp().getInt(); - int32_t output_zp = quant_info.output_zp().getInt(); - - TosaQuantInfoBase *qinfo = new TosaUnaryQuantInfo(input_zp, output_zp); - - return qinfo; -} - -TosaQuantInfoBase * -GetConvQuantInfo(mlir::tosa::ConvOpQuantizationAttr quant_info) { - int32_t input_zp = quant_info.input_zp().getInt(); - int32_t weight_zp = quant_info.weight_zp().getInt(); - - TosaQuantInfoBase *qinfo = new TosaConvQuantInfo(input_zp, weight_zp); - - return qinfo; -} - -TosaQuantInfoBase * -GetPadQuantInfo(mlir::tosa::PadOpQuantizationAttr quant_info) { - int32_t input_zp = quant_info.input_zp().getInt(); - - TosaQuantInfoBase *qinfo = new TosaPadQuantInfo(input_zp); - - return qinfo; -} - -TosaQuantInfoBase * -GetMatMulQuantInfo(mlir::tosa::MatMulOpQuantizationAttr quant_info) { - int32_t a_zp = quant_info.a_zp().getInt(); - int32_t b_zp = quant_info.b_zp().getInt(); - - TosaQuantInfoBase *qinfo = new TosaMatMulQuantInfo(a_zp, b_zp); - - return qinfo; -} - class TosaSerializationBlockBuilder; class TosaSerializationOperatorBuilder { @@ -267,26 +201,19 @@ TosaSerializationOperatorBuilder::BuildPoolOpFromMlirOp(mlir::Operation &op, std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); - TosaPoolAttribute attribute(pad, kernel, stride); auto quant_info = op.getAttrOfType( "quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_UnaryQuantInfo; - qinfo = GetUnaryQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t output_zp = quant_info ? quant_info.output_zp().getInt() : 0; + + TosaPoolAttribute attribute(pad, kernel, stride, input_zp, output_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( - opcode, Attribute_PoolAttribute, &attribute, qinfo_type, qinfo, + opcode, Attribute_PoolAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); - delete qinfo; - return tyop; } @@ -298,7 +225,7 @@ TosaSerializationOperatorBuilder::BuildEwiseBinaryOpFromMlirOp( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + opcode, Attribute_NONE, nullptr, std::vector{input0_name, input1_name}, std::vector{output_name}); @@ -312,7 +239,7 @@ TosaSerializationOperatorBuilder::BuildEwiseUnaryOpFromMlirOp( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - opcode, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + opcode, Attribute_NONE, nullptr, std::vector{input_name}, std::vector{output_name}); @@ -329,7 +256,7 @@ TosaSerializationOperatorBuilder::BuildReductionOpFromMlirOp( TosaAxisAttribute attribute(axis); TosaSerializationOperator *tyop = new TosaSerializationOperator( - opcode, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + opcode, Attribute_AxisAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -560,7 +487,7 @@ TosaSerializationOperatorBuilder::build( ts->SetData(u8_data); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CONST, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + Op_CONST, Attribute_NONE, nullptr, std::vector{}, std::vector{output_name}); return tyop; @@ -597,27 +524,20 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - TosaConvAttribute attribute(pad, stride, dilation); auto quant_info = op.getAttrOfType("quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_ConvQuantInfo; - qinfo = GetConvQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type, qinfo, + Op_CONV2D, Attribute_ConvAttribute, &attribute, std::vector{input0_name, input1_name, input2_name}, std::vector{output_name}); - delete qinfo; - return tyop; } @@ -652,27 +572,19 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - TosaConvAttribute attribute(pad, stride, dilation); - auto quant_info = op.getAttrOfType("quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_ConvQuantInfo; - qinfo = GetConvQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, qinfo_type, - qinfo, std::vector{input0_name, input1_name, input2_name}, + Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, + std::vector{input0_name, input1_name, input2_name}, std::vector{output_name}); - delete qinfo; - return tyop; } @@ -708,28 +620,19 @@ TosaSerializationOperatorBuilder::build( std::string input2_name = GetTensorName(op.getOperand(2)); std::string output_name = GetTensorName(op.getResult(0)); - TosaTransposeConvAttribute attribute(outpad, stride, output_shape); - auto quant_info = op.getAttrOfType("quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_ConvQuantInfo; - qinfo = GetConvQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + + TosaTransposeConvAttribute attribute(outpad, stride, output_shape, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, - qinfo_type, qinfo, std::vector{input0_name, input1_name, input2_name}, std::vector{output_name}); - delete qinfo; - return tyop; } @@ -744,23 +647,16 @@ TosaSerializationOperatorBuilder::build( auto quant_info = op.getAttrOfType("quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_ConvQuantInfo; - qinfo = GetConvQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t weight_zp = quant_info ? quant_info.weight_zp().getInt() : 0; + TosaFullyConnectedAttribute attribute(input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_FULLY_CONNECTED, Attribute_NONE, nullptr, qinfo_type, qinfo, + Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute, std::vector{input0_name, input1_name, input2_name}, std::vector{output_name}); - delete qinfo; - return tyop; } @@ -774,23 +670,17 @@ TosaSerializationOperatorBuilder::build( auto quant_info = op.getAttrOfType( "quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_MatMulQuantInfo; - qinfo = GetMatMulQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } + + int32_t A_zp = quant_info ? quant_info.a_zp().getInt() : 0; + int32_t B_zp = quant_info ? quant_info.b_zp().getInt() : 0; + + TosaMatMulAttribute attribute(A_zp, B_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_MATMUL, Attribute_NONE, nullptr, qinfo_type, qinfo, + Op_MATMUL, Attribute_MatMulAttribute, &attribute, std::vector{input0_name, input1_name}, std::vector{output_name}); - delete qinfo; - return tyop; } @@ -804,7 +694,7 @@ TosaSerializationOperatorBuilder::build( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_SELECT, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + Op_SELECT, Attribute_NONE, nullptr, std::vector{input0_name, input1_name, input2_name}, std::vector{output_name}); @@ -834,7 +724,7 @@ TosaSerializationOperatorBuilder::build( TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CLAMP, Attribute_ClampAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_CLAMP, Attribute_ClampAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -853,7 +743,7 @@ TosaSerializationOperatorBuilder::build( TosaAxisAttribute attribute(axis); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_ARGMAX, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_ARGMAX, Attribute_AxisAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -877,7 +767,7 @@ TosaSerializationOperatorBuilder::build( TosaAxisAttribute attribute(axis); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CONCAT, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_CONCAT, Attribute_AxisAttribute, &attribute, inputs, std::vector{output_name}); return tyop; @@ -892,18 +782,14 @@ TosaSerializationOperatorBuilder::build( auto quant_info = op.getAttrOfType( "quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_UnaryQuantInfo; - qinfo = GetUnaryQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } + + int32_t input_zp = quant_info ? quant_info.input_zp().getInt() : 0; + int32_t output_zp = quant_info ? quant_info.output_zp().getInt() : 0; + + TosaNegateAttribute attribute(input_zp, output_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_NEGATE, Attribute_NONE, nullptr, qinfo_type, qinfo, + Op_NEGATE, Attribute_NegateAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); return tyop; @@ -926,8 +812,8 @@ TosaSerializationOperatorBuilder::build( TosaReshapeAttribute attribute(shape); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_RESHAPE, Attribute_ReshapeAttribute, &attribute, QuantInfo_NONE, - nullptr, std::vector{input_name}, + Op_RESHAPE, Attribute_ReshapeAttribute, &attribute, + std::vector{input_name}, std::vector{output_name}); return tyop; @@ -954,20 +840,8 @@ TosaSerializationOperatorBuilder::build( TosaPadAttribute attribute(paddings, 0 /* pad_const_int */, 0.0f /* pad_const_fp */); - auto quant_info = - op.getAttrOfType("quantization_info"); - QuantInfo qinfo_type; - TosaQuantInfoBase *qinfo; - if (quant_info) { - qinfo_type = QuantInfo_PadQuantInfo; - qinfo = GetPadQuantInfo(quant_info); - } else { - qinfo_type = QuantInfo_NONE; - qinfo = new TosaNoneQuantInfo(); - } - TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_PAD, Attribute_PadAttribute, &attribute, qinfo_type, qinfo, + Op_PAD, Attribute_PadAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); return tyop; @@ -994,8 +868,8 @@ TosaSerializationOperatorBuilder::build( TosaTransposeAttribute attribute(perm); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_TRANSPOSE, Attribute_TransposeAttribute, &attribute, QuantInfo_NONE, - nullptr, std::vector{input_name}, + Op_TRANSPOSE, Attribute_TransposeAttribute, &attribute, + std::vector{input_name}, std::vector{output_name}); return tyop; @@ -1023,7 +897,7 @@ TosaSerializationOperatorBuilder::build( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_SLICE, Attribute_SliceAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_SLICE, Attribute_SliceAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -1047,7 +921,7 @@ TosaSerializationOperatorBuilder::build( TosaTileAttribute attribute(multiples); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_TILE, Attribute_TileAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_TILE, Attribute_TileAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -1063,7 +937,7 @@ TosaSerializationOperatorBuilder::build( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_GATHER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + Op_GATHER, Attribute_NONE, nullptr, std::vector{input0_name, input1_name}, std::vector{output_name}); @@ -1080,7 +954,7 @@ TosaSerializationOperatorBuilder::build( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_SCATTER, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + Op_SCATTER, Attribute_NONE, nullptr, std::vector{input0_name, input1_name, input2_name}, std::vector{output_name}); @@ -1144,7 +1018,7 @@ TosaSerializationOperatorBuilder::build( offset_fp, mode); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_RESIZE, Attribute_ResizeAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_RESIZE, Attribute_ResizeAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -1163,7 +1037,7 @@ TosaSerializationOperatorBuilder::build( TosaAxisAttribute attribute(axis); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_REVERSE, Attribute_AxisAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_REVERSE, Attribute_AxisAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -1183,7 +1057,7 @@ TosaSerializationOperatorBuilder::build( TosaMulAttribute attribute(shift); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_MUL, Attribute_MulAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_MUL, Attribute_MulAttribute, &attribute, std::vector{input0_name, input1_name}, std::vector{output_name}); @@ -1204,7 +1078,7 @@ TosaSerializationOperatorBuilder::build( TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_ARITHMETIC_RIGHT_SHIFT, Attribute_ArithmeticRightShiftAttribute, - &attribute, QuantInfo_NONE, nullptr, + &attribute, std::vector{input0_name, input1_name}, std::vector{output_name}); @@ -1232,7 +1106,7 @@ TosaSerializationOperatorBuilder::build( TosaTableAttribute attribute(table); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_TABLE, Attribute_TableAttribute, &attribute, QuantInfo_NONE, nullptr, + Op_TABLE, Attribute_TableAttribute, &attribute, std::vector{input_name}, std::vector{output_name}); @@ -1273,8 +1147,8 @@ TosaSerializationOperatorBuilder::build( scale32, double_round, per_channel); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_RESCALE, Attribute_RescaleAttribute, &attribute, QuantInfo_NONE, - nullptr, std::vector{input_name}, + Op_RESCALE, Attribute_RescaleAttribute, &attribute, + std::vector{input_name}, std::vector{output_name}); return tyop; @@ -1288,7 +1162,7 @@ TosaSerializationOperatorBuilder::build( std::string output_name = GetTensorName(op.getResult(0)); TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_CUSTOM, Attribute_NONE, nullptr, QuantInfo_NONE, nullptr, + Op_CUSTOM, Attribute_NONE, nullptr, std::vector{input_name}, std::vector{output_name}); @@ -1362,8 +1236,8 @@ TosaSerializationOperatorBuilder::build( } TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_COND_IF, Attribute_CondIfAttribute, &attribute, QuantInfo_NONE, - nullptr, input_names, output_names); + Op_COND_IF, Attribute_CondIfAttribute, &attribute, + input_names, output_names); return tyop; } @@ -1435,8 +1309,8 @@ TosaSerializationOperatorBuilder::build( } TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, QuantInfo_NONE, - nullptr, input_names, output_names); + Op_WHILE_LOOP, Attribute_WhileLoopAttribute, &attribute, + input_names, output_names); return tyop; } diff --git a/third_party/serialization_lib b/third_party/serialization_lib index 7ffa1ff..bdcc3fe 160000 --- a/third_party/serialization_lib +++ b/third_party/serialization_lib @@ -1 +1 @@ -Subproject commit 7ffa1ff137b573e775892836821976e190f28687 +Subproject commit bdcc3fee1b8bf55aac50e060115b92a1ccf9741c -- cgit v1.2.1