diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/TosaDeserialize.cpp | 5 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 11 |
2 files changed, 12 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 0fb02eb..8799028 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -1236,9 +1236,12 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>( auto double_round = op_builder->getBoolAttr(attr->double_round()); auto per_channel = op_builder->getBoolAttr(attr->per_channel()); + auto input_unsigned = op_builder->getBoolAttr(attr->input_unsigned()); + auto output_unsigned = op_builder->getBoolAttr(attr->output_unsigned()); + mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RescaleOp>( loc, output_type, input_val, input_zp, output_zp, multiplier, shift, - scale32, double_round, per_channel); + scale32, double_round, per_channel, input_unsigned, output_unsigned); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 9807a99..263f51c 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -1415,6 +1415,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>( op.getAttr("multiplier").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef(); auto shift = getDenseI8ArrayAttr<int32_t>(op.getAttr("shift")); + bool input_unsigned = + op.getAttr("input_unsigned").dyn_cast<mlir::BoolAttr>().getValue(); + bool output_unsigned = + op.getAttr("output_unsigned").dyn_cast<mlir::BoolAttr>().getValue(); + auto input = op.getOperand(0); auto input_ty = input.getType().cast<mlir::RankedTensorType>(); auto output = op.getResult(0); @@ -1423,9 +1428,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>( std::string input_name = GetTensorName(input); std::string output_name = GetTensorName(output); - TosaRescaleAttribute attribute( - input_zp, output_zp, multiplier, shift, scale32, double_round, - per_channel, input_ty.isUnsignedInteger(), output_ty.isUnsignedInteger()); + TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift, + scale32, double_round, per_channel, + input_unsigned, output_unsigned); TosaSerializationOperator *tyop = new TosaSerializationOperator( Op_RESCALE, Attribute_RescaleAttribute, &attribute, |