aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2023-01-25 16:57:50 +0000
committerJames Ward <james.ward@arm.com>2023-02-08 17:16:42 +0000
commitf514588912a73bd5eafa49008730ddfc9b3969be (patch)
treed176e143d41a8d790ba5022d9628fd8e071a8c8b /src/TosaSerialize.cpp
parenta0bf7c4fb3456b305fc7696967104270efa82875 (diff)
downloadtosa_mlir_translator-f514588912a73bd5eafa49008730ddfc9b3969be.tar.gz
Remove accum-dtype from all but avg_pool2d & remove zero pad
* Remove zero pad from float attribute serialization * Remove accum-dtype from tensor ops to match specification Signed-off-by: James Ward <james.ward@arm.com> Change-Id: I36e179fa0736f34f2c34309d8372d1cf3ab3c763
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp18
1 files changed, 6 insertions, 12 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index ead54f1..31f1cba 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -571,9 +571,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>(
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
mlir::RankedTensorType tensor =
op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- DType type = Type2AccumDType(tensor.getElementType());
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -608,9 +607,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>(
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
mlir::RankedTensorType tensor =
op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- DType type = Type2AccumDType(tensor.getElementType());
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CONV3D, Attribute_ConvAttribute, &attribute,
@@ -645,9 +643,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>(
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
mlir::RankedTensorType tensor =
op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- DType type = Type2AccumDType(tensor.getElementType());
- TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, type);
+ TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute,
@@ -682,9 +679,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>(
int32_t weight_zp = quant_info ? quant_info.getWeightZp() : 0;
mlir::RankedTensorType tensor =
op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- DType type = Type2AccumDType(tensor.getElementType());
- TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp, type);
+ TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, weight_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute,
@@ -711,8 +707,7 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FullyConnectedOp>(
mlir::RankedTensorType tensor =
op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- DType type = Type2AccumDType(tensor.getElementType());
- TosaFullyConnectedAttribute attribute(input_zp, weight_zp, type);
+ TosaFullyConnectedAttribute attribute(input_zp, weight_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_FULLY_CONNECTED, Attribute_FullyConnectedAttribute, &attribute,
@@ -737,9 +732,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::MatMulOp>(
int32_t B_zp = quant_info ? quant_info.getBZp() : 0;
mlir::RankedTensorType tensor =
op.getOperand(0).getType().cast<mlir::RankedTensorType>();
- DType type = Type2AccumDType(tensor.getElementType());
- TosaMatMulAttribute attribute(A_zp, B_zp, type);
+ TosaMatMulAttribute attribute(A_zp, B_zp);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_MATMUL, Attribute_MatMulAttribute, &attribute,