aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp39
1 files changed, 34 insertions, 5 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 9660833..d69f005 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -937,13 +937,42 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
TosaClampAttribute *attr =
static_cast<TosaClampAttribute *>(op->GetAttribute());
- auto min_int = op_builder->getI64IntegerAttr(attr->min_int());
- auto max_int = op_builder->getI64IntegerAttr(attr->max_int());
- auto min_fp = op_builder->getF32FloatAttr(attr->min_fp());
- auto max_fp = op_builder->getF32FloatAttr(attr->max_fp());
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType();
+ if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
+ input_element_type)) {
+ input_element_type = quantType.getStorageType();
+ }
+
+ mlir::Attribute min_val_attr, max_val_attr;
+ if (input_element_type.isa<mlir::FloatType>()) {
+ min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp());
+ max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp());
+ } else if (input_element_type.isUnsignedInteger()) {
+ if (input_element_type.isUnsignedInteger(8)) {
+ uint8_t min_val = attr->min_int();
+ uint8_t max_val = attr->max_int();
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else if (input_element_type.isUnsignedInteger(16)) {
+ uint16_t min_val = attr->min_int();
+ uint16_t max_val = attr->max_int();
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else {
+ llvm::errs() << "ERROR: " << get_string(op)
+ << " contains unsupported unsigned int element data type.\n";
+ return {};
+ }
+ } else {
+ min_val_attr =
+ op_builder->getIntegerAttr(input_element_type, attr->min_int());
+ max_val_attr =
+ op_builder->getIntegerAttr(input_element_type, attr->max_int());
+ }
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
- loc, output_type, input_val, min_int, max_int, min_fp, max_fp);
+ loc, output_type, input_val, min_val_attr, max_val_attr);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}