aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-08 18:32:46 +0000
committerTai Ly <tai.ly@arm.com>2024-03-17 19:55:42 -0700
commit909d4d159ee12c6bc8113974d76f46249b6fd7fb (patch)
tree7e6320d8f74ba6478a654404ccc74cca2ff3219f /src/TosaDeserialize.cpp
parentf983e51df5030facfd1c5bf59dcc67a32a1913a8 (diff)
downloadtosa_mlir_translator-909d4d159ee12c6bc8113974d76f46249b6fd7fb.tar.gz
[tosa_mlir_translator] Use new Clamp and Pad fbs attributes
This implements changes required for new Tosa Flatbuffer schema where Clamp and Pad attributes have changed to use ubyte arrays to store int or float values. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I2aa2025422fda4aacaf6d80727060a01a30cee89
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp78
1 files changed, 53 insertions, 25 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index d69f005..b12d22a 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
mlir::Attribute min_val_attr, max_val_attr;
if (input_element_type.isa<mlir::FloatType>()) {
- min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp());
- max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp());
- } else if (input_element_type.isUnsignedInteger()) {
- if (input_element_type.isUnsignedInteger(8)) {
- uint8_t min_val = attr->min_int();
- uint8_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else if (input_element_type.isUnsignedInteger(16)) {
- uint16_t min_val = attr->min_int();
- uint16_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else {
- llvm::errs() << "ERROR: " << get_string(op)
- << " contains unsupported unsigned int element data type.\n";
- return {};
- }
- } else {
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1,
+ min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1,
+ max_float_data);
min_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->min_int());
+ op_builder->getFloatAttr(input_element_type, min_float_data[0]);
max_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->max_int());
+ op_builder->getFloatAttr(input_element_type, max_float_data[0]);
+ } else {
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1,
+ min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1,
+ max_int_data);
+ if (input_element_type.isUnsignedInteger()) {
+ if (input_element_type.isUnsignedInteger(8)) {
+ uint8_t min_val = min_int_data[0];
+ uint8_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else if (input_element_type.isUnsignedInteger(16)) {
+ uint16_t min_val = min_int_data[0];
+ uint16_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " contains unsupported unsigned int element data type.\n";
+ return {};
+ }
+ } else {
+ min_val_attr =
+ op_builder->getIntegerAttr(input_element_type, min_int_data[0]);
+ max_val_attr =
+ op_builder->getIntegerAttr(input_element_type, max_int_data[0]);
+ }
}
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
@@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
tensor_type_map->at(op->GetInputTensorNames()[0]);
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ const auto element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
assert(op->GetAttributeType() ==
Attribute_PadAttribute); // double check attribute type
TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
- auto pad_const_int = attr->pad_const_int();
- auto pad_const_fp = attr->pad_const_fp();
+ float pad_const_fp = 0.0f;
+ int32_t pad_const_int = 0;
+
+ if (element_type.isa<mlir::FloatType>()) {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->pad_const(),
+ /* size = */ 1, float_data);
+ pad_const_fp = float_data[0];
+ } else {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->pad_const(),
+ /* size = */ 1, int32_data);
+ pad_const_int = int32_data[0];
+ }
+
// todo: int input_zp = attr->pad_input_zp();
mlir::Operation *mlir_op;
mlir::Value pad_const_value;
- const auto element_type =
- input_val.getType().cast<mlir::ShapedType>().getElementType();
bool isBoolType = element_type.isInteger(1);
// First handle boolean type.