diff options
author | Tai Ly <tai.ly@arm.com> | 2023-11-13 17:22:27 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-11-14 23:10:08 +0000 |
commit | d73f3d7de24048f491a9e02ca50be0f069ef10b1 (patch) | |
tree | 1082fcb6a3e5e7ec0e250b56f299709b266c34d0 /src/TosaSerialize.cpp | |
parent | 83d9bd1433cf8ec8ac093edccc6a49ad0255fe73 (diff) | |
download | tosa_mlir_translator-d73f3d7de24048f491a9e02ca50be0f069ef10b1.tar.gz |
[tosa_mlir_translator] Add local_bound support
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I19b86b954574bebdcac42205c65a23e7e43acb72
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 48 |
1 files changed, 42 insertions, 6 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 8aef8fd..a131117 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -790,7 +790,13 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv2DOp>( mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, + local_bound); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV2D, Attribute_ConvAttribute, &attribute, @@ -826,7 +832,13 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::Conv3DOp>( mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, + local_bound); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_CONV3D, Attribute_ConvAttribute, &attribute, @@ -862,7 +874,13 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::DepthwiseConv2DOp>( mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); - TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + TosaConvAttribute attribute(pad, stride, dilation, input_zp, weight_zp, + local_bound); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_DEPTHWISE_CONV2D, Attribute_ConvAttribute, &attribute, @@ -898,8 +916,13 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::TransposeConv2DOp>( mlir::RankedTensorType tensor = op.getOperand(0).getType().cast<mlir::RankedTensorType>(); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + TosaTransposeConvAttribute attribute(out_pad, stride, out_shape, input_zp, - weight_zp); + weight_zp, local_bound); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_TRANSPOSE_CONV2D, Attribute_TransposeConvAttribute, &attribute, @@ -1597,8 +1620,16 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RFFT2dOp>( std::string output_real_name = GetTensorName(op.getResult(0)); std::string output_imag_name = GetTensorName(op.getResult(1)); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + + TosaRFFTAttribute attribute(local_bound); + TosaSerializationOperator *tyop = new TosaSerializationOperator( - Op_RFFT2D, Attribute_NONE, nullptr, std::vector<std::string>{input_name}, + Op_RFFT2D, Attribute_RFFTAttribute, &attribute, + std::vector<std::string>{input_name}, std::vector<std::string>{output_real_name, output_imag_name}); return tyop; @@ -1611,12 +1642,17 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::FFT2dOp>( bool inverse = op.getAttr("inverse").dyn_cast<mlir::BoolAttr>().getValue(); + bool local_bound = + op.hasAttr("local_bound") + ? op.getAttr("local_bound").dyn_cast<mlir::BoolAttr>().getValue() + : false; + std::string input_real_name = GetTensorName(op.getOperand(0)); std::string input_imag_name = GetTensorName(op.getOperand(1)); std::string output_real_name = GetTensorName(op.getResult(0)); std::string output_imag_name = GetTensorName(op.getResult(1)); - TosaFFTAttribute attribute(inverse); + TosaFFTAttribute attribute(inverse, local_bound); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_FFT2D, Attribute_FFTAttribute, &attribute, |