From f514588912a73bd5eafa49008730ddfc9b3969be Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 25 Jan 2023 16:57:50 +0000 Subject: Remove accum-dtype from all but avg_pool2d & remove zero pad * Remove zero pad from float attribute serialization * Remove accum-dtype from tensor ops to match specification Signed-off-by: James Ward Change-Id: I36e179fa0736f34f2c34309d8372d1cf3ab3c763 --- src/TosaSerialize.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index ead54f1..31f1cba 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -571,9 +571,8 @@ TosaSerializationOperatorBuilder::build( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -608,9 +607,8 @@ TosaSerializationOperatorBuilder::build( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -645,9 +643,8 @@ TosaSerializationOperatorBuilder::build( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -682,9 +679,8 @@ TosaSerializationOperatorBuilder::build( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp, type); + TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, @@ -711,8 +707,7 @@ TosaSerializationOperatorBuilder::build( mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaFullyConnectedAttribute attribute(input_zp, weight_zp, type); + TosaFullyConnectedAttribute attribute(input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute, @@ -737,9 +732,8 @@ TosaSerializationOperatorBuilder::build( int32_t B_zp = quant_info ? quant_info.getBZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaMatMulAttribute attribute(A_zp, B_zp, type); + TosaMatMulAttribute attribute(A_zp, B_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_MATMUL, Attribute_MatMulAttribute, &attribute, -- cgit v1.2.1