diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/TosaDeserialize.cpp | 39 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 39 |
2 files changed, 61 insertions, 17 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 9660833..d69f005 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -937,13 +937,42 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const { TosaClampAttribute *attr = static_cast<TosaClampAttribute *>(op->GetAttribute()); - auto min_int = op_builder->getI64IntegerAttr(attr->min_int()); - auto max_int = op_builder->getI64IntegerAttr(attr->max_int()); - auto min_fp = op_builder->getF32FloatAttr(attr->min_fp()); - auto max_fp = op_builder->getF32FloatAttr(attr->max_fp()); + mlir::Type input_element_type = + llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType(); + if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>( + input_element_type)) { + input_element_type = quantType.getStorageType(); + } + + mlir::Attribute min_val_attr, max_val_attr; + if (input_element_type.isa<mlir::FloatType>()) { + min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp()); + max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp()); + } else if (input_element_type.isUnsignedInteger()) { + if (input_element_type.isUnsignedInteger(8)) { + uint8_t min_val = attr->min_int(); + uint8_t max_val = attr->max_int(); + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else if (input_element_type.isUnsignedInteger(16)) { + uint16_t min_val = attr->min_int(); + uint16_t max_val = attr->max_int(); + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else { + llvm::errs() << "ERROR: " << get_string(op) + << " contains unsupported unsigned int element data type.\n"; + return {}; + } + } else { + min_val_attr = + op_builder->getIntegerAttr(input_element_type, attr->min_int()); + max_val_attr = + op_builder->getIntegerAttr(input_element_type, attr->max_int()); + } mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>( - loc, output_type, input_val, min_int, max_int, min_fp, max_fp); + loc, output_type, input_val, min_val_attr, max_val_attr); block->push_back(mlir_op); return std::vector<mlir::Value>({mlir_op->getResult(0)}); } diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 74988c8..54a1d28 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -952,18 +952,33 @@ template <> TosaSerializationOperator * TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>( mlir::Operation &op) const { - int32_t min_int = - op.getAttr("min_int").dyn_cast<mlir::IntegerAttr>().getInt(); - int32_t max_int = - op.getAttr("max_int").dyn_cast<mlir::IntegerAttr>().getInt(); - float min_fp = op.getAttr("min_fp") - .dyn_cast<mlir::FloatAttr>() - .getValue() - .convertToFloat(); - float max_fp = op.getAttr("max_fp") - .dyn_cast<mlir::FloatAttr>() - .getValue() - .convertToFloat(); + auto min_val_attr = op.getAttr("min_val"); + auto max_val_attr = op.getAttr("max_val"); + float min_fp = 0; + float max_fp = 0; + int32_t min_int = 0; + int32_t max_int = 0; + + mlir::Type input_element_type = + llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType(); + if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>( + input_element_type)) { + input_element_type = quantType.getStorageType(); + } + + if (input_element_type.isa<mlir::FloatType>()) { + min_fp = + mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat(); + max_fp = + mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat(); + } else if (input_element_type.isUnsignedInteger()) { + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt(); + } else { + assert(input_element_type.isa<mlir::IntegerType>()); + min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt(); + max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt(); + } std::string input_name = GetTensorName(op.getOperand(0)); std::string output_name = GetTensorName(op.getResult(0)); |