diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-08 18:32:46 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-03-17 19:55:42 -0700 |
commit | 909d4d159ee12c6bc8113974d76f46249b6fd7fb (patch) | |
tree | 7e6320d8f74ba6478a654404ccc74cca2ff3219f /src/TosaDeserialize.cpp | |
parent | f983e51df5030facfd1c5bf59dcc67a32a1913a8 (diff) | |
download | tosa_mlir_translator-909d4d159ee12c6bc8113974d76f46249b6fd7fb.tar.gz |
[tosa_mlir_translator] Use new Clamp and Pad fbs attributes
This implements changes required for new Tosa Flatbuffer
schema where Clamp and Pad attributes have changed to use
ubyte arrays to store int or float values.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I2aa2025422fda4aacaf6d80727060a01a30cee89
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 78 |
1 files changed, 53 insertions, 25 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index d69f005..b12d22a 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const { mlir::Attribute min_val_attr, max_val_attr; if (input_element_type.isa<mlir::FloatType>()) { - min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp()); - max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp()); - } else if (input_element_type.isUnsignedInteger()) { - if (input_element_type.isUnsignedInteger(8)) { - uint8_t min_val = attr->min_int(); - uint8_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else if (input_element_type.isUnsignedInteger(16)) { - uint16_t min_val = attr->min_int(); - uint16_t max_val = attr->max_int(); - min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); - max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); - } else { - llvm::errs() << "ERROR: " << get_string(op) - << " contains unsupported unsigned int element data type.\n"; - return {}; - } - } else { + std::vector<float> min_float_data, max_float_data; + TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1, + min_float_data); + TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1, + max_float_data); min_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->min_int()); + op_builder->getFloatAttr(input_element_type, min_float_data[0]); max_val_attr = - op_builder->getIntegerAttr(input_element_type, attr->max_int()); + op_builder->getFloatAttr(input_element_type, max_float_data[0]); + } else { + std::vector<int32_t> min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1, + min_int_data); + TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1, + max_int_data); + if (input_element_type.isUnsignedInteger()) { + if (input_element_type.isUnsignedInteger(8)) { + uint8_t min_val = min_int_data[0]; + uint8_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else if (input_element_type.isUnsignedInteger(16)) { + uint16_t min_val = min_int_data[0]; + uint16_t max_val = max_int_data[0]; + min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val); + max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val); + } else { + llvm::errs() + << "ERROR: " << get_string(op) + << " contains unsupported unsigned int element data type.\n"; + return {}; + } + } else { + min_val_attr = + op_builder->getIntegerAttr(input_element_type, min_int_data[0]); + max_val_attr = + op_builder->getIntegerAttr(input_element_type, max_int_data[0]); + } } mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>( @@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const { tensor_type_map->at(op->GetInputTensorNames()[0]); mlir::RankedTensorType output_type = tensor_type_map->at(op->GetOutputTensorNames()[0]); + const auto element_type = + input_val.getType().cast<mlir::ShapedType>().getElementType(); assert(op->GetAttributeType() == Attribute_PadAttribute); // double check attribute type TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute()); - auto pad_const_int = attr->pad_const_int(); - auto pad_const_fp = attr->pad_const_fp(); + float pad_const_fp = 0.0f; + int32_t pad_const_int = 0; + + if (element_type.isa<mlir::FloatType>()) { + std::vector<float> float_data; + TosaSerializationHandler::ConvertU8toF32(attr->pad_const(), + /* size = */ 1, float_data); + pad_const_fp = float_data[0]; + } else { + std::vector<int32_t> int32_data; + TosaSerializationHandler::ConvertU8toI32(attr->pad_const(), + /* size = */ 1, int32_data); + pad_const_int = int32_data[0]; + } + // todo: int input_zp = attr->pad_input_zp(); mlir::Operation *mlir_op; mlir::Value pad_const_value; - const auto element_type = - input_val.getType().cast<mlir::ShapedType>().getElementType(); bool isBoolType = element_type.isInteger(1); // First handle boolean type. |