diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 21798f3..9649644 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -673,22 +673,21 @@ TosaMlirOperatorBuilder::BuildConvOp(TosaSerializationOperator *op) const { BuildDenseI64ArrayAttr(op_builder, attr->dilation()); auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); + bool local_bound = attr->local_bound(); // quantizationattr is required for quantized type, and not allowed for float // type mlir::Operation *mlir_op; if (output_type.getElementType().isa<mlir::FloatType>()) { assert(input_zp == 0 && weight_zp == 0); - mlir_op = - op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val, - input2_val, pad, stride, dilation); - } else { - auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>( - input_zp, weight_zp); - mlir_op = - op_builder->create<MLIR_OP>(loc, output_type, input0_val, input1_val, - input2_val, pad, stride, dilation, quant); } + + auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>( + input_zp, weight_zp); + mlir_op = op_builder->create<MLIR_OP>(loc, output_type, input0_val, + input1_val, input2_val, pad, stride, + dilation, quant, local_bound); + block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); } @@ -726,22 +725,20 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_TRANSPOSE_CONV2D>( BuildDenseI64ArrayAttr(op_builder, attr->output_shape()); auto input_zp = attr->input_zp(); auto weight_zp = attr->weight_zp(); + bool local_bound = attr->local_bound(); // quantizationattr is required for quantized type, and not allowed for float // type mlir::Operation *mlir_op; if (output_type.getElementType().isa<mlir::FloatType>()) { assert(input_zp == 0 && weight_zp == 0); - mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>( - loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, - output_shape); - } else { - auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>( - input_zp, weight_zp); - mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>( - loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, - output_shape, quant); } + auto quant = op_builder->getAttr<mlir::tosa::ConvOpQuantizationAttr>( + input_zp, weight_zp); + mlir_op = op_builder->create<mlir::tosa::TransposeConv2DOp>( + loc, output_type, input0_val, input1_val, input2_val, out_pad, stride, + output_shape, quant, local_bound); + block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); } @@ -1283,10 +1280,14 @@ TosaMlirOperatorBuilder::build<Op_RFFT2D>(TosaSerializationOperator *op) const { mlir::RankedTensorType output1_type = tensor_type_map->at(op->GetOutputTensorNames()[1]); assert(op->GetAttributeType() == - Attribute_NONE); // double check that there is no attribute + Attribute_RFFTAttribute); // double check attribute type + TosaRFFTAttribute *attr = + static_cast<TosaRFFTAttribute *>(op->GetAttribute()); + + bool local_bound = attr->local_bound(); mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RFFT2dOp>( - loc, output0_type, output1_type, input_val); + loc, output0_type, output1_type, input_val, local_bound); block->push_back(mlir_op); return std::vector<mlir::Value>( {mlir_op->getResult(0), mlir_op->getResult(1)}); @@ -1305,9 +1306,11 @@ TosaMlirOperatorBuilder::build<Op_FFT2D>(TosaSerializationOperator *op) const { assert(op->GetAttributeType() == Attribute_FFTAttribute); TosaFFTAttribute *attr = static_cast<TosaFFTAttribute *>(op->GetAttribute()); auto inverse = op_builder->getBoolAttr(attr->inverse()); + auto local_bound = op_builder->getBoolAttr(attr->local_bound()); mlir::Operation *mlir_op = op_builder->create<mlir::tosa::FFT2dOp>( - loc, output0_type, output1_type, input0_val, input1_val, inverse); + loc, output0_type, output1_type, input0_val, input1_val, inverse, + local_bound); block->push_back(mlir_op); return std::vector<mlir::Value>( {mlir_op->getResult(0), mlir_op->getResult(1)}); |