diff options
author | James Ward <james.ward@arm.com> | 2023-01-25 16:57:50 +0000 |
---|---|---|
committer | James Ward <james.ward@arm.com> | 2023-02-08 17:16:42 +0000 |
commit | f514588912a73bd5eafa49008730ddfc9b3969be (patch) | |
tree | d176e143d41a8d790ba5022d9628fd8e071a8c8b | |
parent | a0bf7c4fb3456b305fc7696967104270efa82875 (diff) | |
download | tosa_mlir_translator-f514588912a73bd5eafa49008730ddfc9b3969be.tar.gz |
Remove accum-dtype from all but avg_pool2d & remove zero pad
* Remove zero pad from float attribute serialization
* Remove accum-dtype from tensor ops to match specification
Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: I36e179fa0736f34f2c34309d8372d1cf3ab3c763
-rw-r--r-- | src/TosaSerialize.cpp | 18 | ||||
m--------- | third_party/serialization_lib | 0 |
2 files changed, 6 insertions, 12 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index ead54f1..31f1cba 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -571,9 +571,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -608,9 +607,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -645,9 +643,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type); + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -682,9 +679,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>( int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp, type); + TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, @@ -711,8 +707,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>( mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaFullyConnectedAttribute attribute(input_zp, weight_zp, type); + TosaFullyConnectedAttribute attribute(input_zp, weight_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute, @@ -737,9 +732,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>( int32_t B_zp = quant_info ? quant_info.getBZp() : 0; mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - DType type = Type2AccumDType(tensor.getElementType()); - TosaMatMulAttribute attribute(A_zp, B_zp, type); + TosaMatMulAttribute attribute(A_zp, B_zp); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_MATMUL, Attribute_MatMulAttribute, &attribute, diff --git a/third_party/serialization_lib b/third_party/serialization_lib -Subproject c15f7d52aa4f360eba2344449baa418b7608ac7 +Subproject 80905bba37ce55e8db293b1405a78b63dc1855c |