aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2023-11-14 18:45:33 +0000
committerJames Ward <james.ward@arm.com>2023-11-27 10:16:00 +0000
commitfc32f56a067c526238c15de097fe78fdcab95cb5 (patch)
treec1a4041b7ec9079ab6d1fdce8fcc6395538662e1
parent546e9990065804f6304a216b42468bf44c8c1036 (diff)
downloadtosa_mlir_translator-fc32f56a067c526238c15de097fe78fdcab95cb5.tar.gz
Add Rescale Attribute changes
Signed-off-by: James Ward <james.ward@arm.com> Change-Id: I8ac71800d922526aad0a7c351ad1943481208cc2
-rw-r--r--src/TosaDeserialize.cpp5
-rw-r--r--src/TosaSerialize.cpp11
2 files changed, 12 insertions, 4 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 0fb02eb..8799028 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -1236,9 +1236,12 @@ std::vector<mlir::Value> TosaMlirOperatorBuilder::build<Op_RESCALE>(
auto double_round = op_builder->getBoolAttr(attr->double_round());
auto per_channel = op_builder->getBoolAttr(attr->per_channel());
+ auto input_unsigned = op_builder->getBoolAttr(attr->input_unsigned());
+ auto output_unsigned = op_builder->getBoolAttr(attr->output_unsigned());
+
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::RescaleOp>(
loc, output_type, input_val, input_zp, output_zp, multiplier, shift,
- scale32, double_round, per_channel);
+ scale32, double_round, per_channel, input_unsigned, output_unsigned);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 9807a99..263f51c 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -1415,6 +1415,11 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
op.getAttr("multiplier").dyn_cast<mlir::DenseI32ArrayAttr>().asArrayRef();
auto shift = getDenseI8ArrayAttr<int32_t>(op.getAttr("shift"));
+ bool input_unsigned =
+ op.getAttr("input_unsigned").dyn_cast<mlir::BoolAttr>().getValue();
+ bool output_unsigned =
+ op.getAttr("output_unsigned").dyn_cast<mlir::BoolAttr>().getValue();
+
auto input = op.getOperand(0);
auto input_ty = input.getType().cast<mlir::RankedTensorType>();
auto output = op.getResult(0);
@@ -1423,9 +1428,9 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
std::string input_name = GetTensorName(input);
std::string output_name = GetTensorName(output);
- TosaRescaleAttribute attribute(
- input_zp, output_zp, multiplier, shift, scale32, double_round,
- per_channel, input_ty.isUnsignedInteger(), output_ty.isUnsignedInteger());
+ TosaRescaleAttribute attribute(input_zp, output_zp, multiplier, shift,
+ scale32, double_round, per_channel,
+ input_unsigned, output_unsigned);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESCALE, Attribute_RescaleAttribute, &attribute,