diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-07 20:22:21 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-03-07 15:56:40 -0800 |
commit | f983e51df5030facfd1c5bf59dcc67a32a1913a8 (patch) | |
tree | fb7130bcca1f3372b3b239e777e9fef3fe27eb66 /src/TosaSerialize.cpp | |
parent | eecc90165c5d37b75aca83f4d00cb4481872b238 (diff) | |
download | tosa_mlir_translator-f983e51df5030facfd1c5bf59dcc67a32a1913a8.tar.gz |
[tosa_mlir_translator] ClampOp attributes changes
This patch adjusts serialization/deserialization for
ClampOp's new min_val/max_val attributes
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I67f275334fdc8996142b0c4541d55ef65ea66274
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 74988c8..54a1d28 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -952,18 +952,33 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>( mlir::Operation &op) const { - int32_t min_int = - op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt(); - int32_t max_int = - op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt(); - float min_fp = op.getAttr("min_fp") - .dyn_cast<mlir::FloatAttr>() - .getValue() - .convertToFloat(); - float max_fp = op.getAttr("max_fp") - .dyn_cast<mlir::FloatAttr>() - .getValue() - .convertToFloat(); + auto min_val_attr = op.getAttr("min_val"); + auto max_val_attr = op.getAttr("max_val"); + float min_fp = 0; + float max_fp = 0; + int32_t min_int = 0; + int32_t max_int = 0; + + mlir::Type input_element_type = + llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); + if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>( + input_element_type)) { + input_element_type = quantType.getStorageType(); + } + + if (input_element_type.isa<mlir::FloatType>()) { + min_fp = + mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat(); + max_fp = + mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat(); + } else if (input_element_type.isUnsignedInteger()) { + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt(); + } else { + assert(input_element_type.isa<mlir::IntegerType>()); + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt(); + } std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); |