diff options
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 18 |
1 files changed, 6 insertions, 12 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index ead54f1..31f1cba 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -571,9 +571,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -608,9 +607,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -645,9 +643,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -682,9 +679,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp, type); + TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, @@ -711,8 +707,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>( mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaFullyConnectedAttribute attribute(input_zp, weight_zp, type); + TosaFullyConnectedAttribute attribute(input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute, @@ -737,9 +732,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>( int32_t B_zp = quant_info ? quant_info.getBZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaMatMulAttribute attribute(A_zp, B_zp, type); + TosaMatMulAttribute attribute(A_zp, B_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_MATMUL, Attribute_MatMulAttribute, &attribute, |