aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-07 20:22:21 +0000
committerTai Ly <tai.ly@arm.com>2024-03-07 15:56:40 -0800
commitf983e51df5030facfd1c5bf59dcc67a32a1913a8 (patch)
treefb7130bcca1f3372b3b239e777e9fef3fe27eb66 /src/TosaSerialize.cpp
parenteecc90165c5d37b75aca83f4d00cb4481872b238 (diff)
downloadtosa_mlir_translator-f983e51df5030facfd1c5bf59dcc67a32a1913a8.tar.gz
[tosa_mlir_translator] ClampOp attributes changes
This patch adjusts serialization/deserialization for ClampOp's new min_val/max_val attributes Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I67f275334fdc8996142b0c4541d55ef65ea66274
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp39
1 files changed, 27 insertions, 12 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 74988c8..54a1d28 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -952,18 +952,33 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
mlir::Operation &op) const {
- int32_t min_int =
- op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt();
- int32_t max_int =
- op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt();
- float min_fp = op.getAttr("min_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
- float max_fp = op.getAttr("max_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
+ auto min_val_attr = op.getAttr("min_val");
+ auto max_val_attr = op.getAttr("max_val");
+ float min_fp = 0;
+ float max_fp = 0;
+ int32_t min_int = 0;
+ int32_t max_int = 0;
+
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
+ if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
+ input_element_type)) {
+ input_element_type = quantType.getStorageType();
+ }
+
+ if (input_element_type.isa<mlir::FloatType>()) {
+ min_fp =
+ mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
+ max_fp =
+ mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
+ } else if (input_element_type.isUnsignedInteger()) {
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ } else {
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ }
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));