aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-07 20:22:21 +0000
committerTai Ly <tai.ly@arm.com>2024-03-07 15:56:40 -0800
commitf983e51df5030facfd1c5bf59dcc67a32a1913a8 (patch)
treefb7130bcca1f3372b3b239e777e9fef3fe27eb66
parenteecc90165c5d37b75aca83f4d00cb4481872b238 (diff)
downloadtosa_mlir_translator-f983e51df5030facfd1c5bf59dcc67a32a1913a8.tar.gz
[tosa_mlir_translator] ClampOp attributes changes
This patch adjusts serialization/deserialization for ClampOp's new min_val/max_val attributes Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I67f275334fdc8996142b0c4541d55ef65ea66274
-rw-r--r--src/TosaDeserialize.cpp39
-rw-r--r--src/TosaSerialize.cpp39
2 files changed, 61 insertions, 17 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 9660833..d69f005 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -937,13 +937,42 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
TosaClampAttribute *attr =
static_cast<TosaClampAttribute *>(op->GetAttribute());
- auto min_int = op_builder->getI64IntegerAttr(attr->min_int());
- auto max_int = op_builder->getI64IntegerAttr(attr->max_int());
- auto min_fp = op_builder->getF32FloatAttr(attr->min_fp());
- auto max_fp = op_builder->getF32FloatAttr(attr->max_fp());
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType();
+ if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
+ input_element_type)) {
+ input_element_type = quantType.getStorageType();
+ }
+
+ mlir::Attribute min_val_attr, max_val_attr;
+ if (input_element_type.isa<mlir::FloatType>()) {
+ min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp());
+ max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp());
+ } else if (input_element_type.isUnsignedInteger()) {
+ if (input_element_type.isUnsignedInteger(8)) {
+ uint8_t min_val = attr->min_int();
+ uint8_t max_val = attr->max_int();
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else if (input_element_type.isUnsignedInteger(16)) {
+ uint16_t min_val = attr->min_int();
+ uint16_t max_val = attr->max_int();
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " contains unsupported unsigned int element data type.\n";
+ return {};
+ }
+ } else {
+ min_val_attr =
+ op_builder->getIntegerAttr(input_element_type, attr->min_int());
+ max_val_attr =
+ op_builder->getIntegerAttr(input_element_type, attr->max_int());
+ }
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
- loc, output_type, input_val, min_int, max_int, min_fp, max_fp);
+ loc, output_type, input_val, min_val_attr, max_val_attr);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 74988c8..54a1d28 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -952,18 +952,33 @@ template <>
TosaSerializationOperator *
TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
mlir::Operation &op) const {
- int32_t min_int =
- op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt();
- int32_t max_int =
- op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt();
- float min_fp = op.getAttr("min_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
- float max_fp = op.getAttr("max_fp")
- .dyn_cast<mlir::FloatAttr>()
- .getValue()
- .convertToFloat();
+ auto min_val_attr = op.getAttr("min_val");
+ auto max_val_attr = op.getAttr("max_val");
+ float min_fp = 0;
+ float max_fp = 0;
+ int32_t min_int = 0;
+ int32_t max_int = 0;
+
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
+ if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
+ input_element_type)) {
+ input_element_type = quantType.getStorageType();
+ }
+
+ if (input_element_type.isa<mlir::FloatType>()) {
+ min_fp =
+ mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
+ max_fp =
+ mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
+ } else if (input_element_type.isUnsignedInteger()) {
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ } else {
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ }
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));