aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-08 18:32:46 +0000
committerTai Ly <tai.ly@arm.com>2024-03-17 19:55:42 -0700
commit909d4d159ee12c6bc8113974d76f46249b6fd7fb (patch)
tree7e6320d8f74ba6478a654404ccc74cca2ff3219f
parentf983e51df5030facfd1c5bf59dcc67a32a1913a8 (diff)
downloadtosa_mlir_translator-909d4d159ee12c6bc8113974d76f46249b6fd7fb.tar.gz
[tosa_mlir_translator] Use new Clamp and Pad fbs attributes
This implements changes required for new Tosa Flatbuffer schema where Clamp and Pad attributes have changed to use ubyte arrays to store int or float values. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I2aa2025422fda4aacaf6d80727060a01a30cee89
-rw-r--r--src/TosaDeserialize.cpp78
-rw-r--r--src/TosaSerialize.cpp47
m---------third_party/serialization_lib0
3 files changed, 82 insertions, 43 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index d69f005..b12d22a 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
mlir::Attribute min_val_attr, max_val_attr;
if (input_element_type.isa<mlir::FloatType>()) {
- min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp());
- max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp());
- } else if (input_element_type.isUnsignedInteger()) {
- if (input_element_type.isUnsignedInteger(8)) {
- uint8_t min_val = attr->min_int();
- uint8_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else if (input_element_type.isUnsignedInteger(16)) {
- uint16_t min_val = attr->min_int();
- uint16_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else {
- llvm::errs() << "ERROR: " << get_string(op)
- << " contains unsupported unsigned int element data type.\n";
- return {};
- }
- } else {
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1,
+ min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1,
+ max_float_data);
min_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->min_int());
+ op_builder->getFloatAttr(input_element_type, min_float_data[0]);
max_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->max_int());
+ op_builder->getFloatAttr(input_element_type, max_float_data[0]);
+ } else {
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1,
+ min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1,
+ max_int_data);
+ if (input_element_type.isUnsignedInteger()) {
+ if (input_element_type.isUnsignedInteger(8)) {
+ uint8_t min_val = min_int_data[0];
+ uint8_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else if (input_element_type.isUnsignedInteger(16)) {
+ uint16_t min_val = min_int_data[0];
+ uint16_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " contains unsupported unsigned int element data type.\n";
+ return {};
+ }
+ } else {
+ min_val_attr =
+ op_builder->getIntegerAttr(input_element_type, min_int_data[0]);
+ max_val_attr =
+ op_builder->getIntegerAttr(input_element_type, max_int_data[0]);
+ }
}
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
@@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
tensor_type_map->at(op->GetInputTensorNames()[0]);
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ const auto element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
assert(op->GetAttributeType() ==
Attribute_PadAttribute); // double check attribute type
TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
- auto pad_const_int = attr->pad_const_int();
- auto pad_const_fp = attr->pad_const_fp();
+ float pad_const_fp = 0.0f;
+ int32_t pad_const_int = 0;
+
+ if (element_type.isa<mlir::FloatType>()) {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->pad_const(),
+ /* size = */ 1, float_data);
+ pad_const_fp = float_data[0];
+ } else {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->pad_const(),
+ /* size = */ 1, int32_data);
+ pad_const_int = int32_data[0];
+ }
+
// todo: int input_zp = attr->pad_input_zp();
mlir::Operation *mlir_op;
mlir::Value pad_const_value;
- const auto element_type =
- input_val.getType().cast<mlir::ShapedType>().getElementType();
bool isBoolType = element_type.isInteger(1);
// First handle boolean type.
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 54a1d28..55a11fd 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -954,10 +954,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
mlir::Operation &op) const {
auto min_val_attr = op.getAttr("min_val");
auto max_val_attr = op.getAttr("max_val");
- float min_fp = 0;
- float max_fp = 0;
- int32_t min_int = 0;
- int32_t max_int = 0;
mlir::Type input_element_type =
llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
@@ -966,24 +962,33 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
input_element_type = quantType.getStorageType();
}
+ std::vector<uint8_t> min_val, max_val;
if (input_element_type.isa<mlir::FloatType>()) {
- min_fp =
+ auto min_fp =
mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
- max_fp =
+ auto max_fp =
mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
- } else if (input_element_type.isUnsignedInteger()) {
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val);
+ TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val);
} else {
- assert(input_element_type.isa<mlir::IntegerType>());
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ int32_t min_int = 0;
+ int32_t max_int = 0;
+ if (input_element_type.isUnsignedInteger()) {
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ } else {
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ }
+ TosaSerializationHandler::ConvertI32toU8({min_int}, min_val);
+ TosaSerializationHandler::ConvertI32toU8({max_int}, max_val);
}
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp);
+ TosaClampAttribute attribute(min_val, max_val);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CLAMP, Attribute_ClampAttribute, &attribute,
@@ -1128,7 +1133,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
}
}
- TosaPadAttribute attribute({}, pad_const_int, pad_const_fp);
+ std::vector<uint8_t> pad_const;
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
+ if (input_element_type.isa<mlir::FloatType>()) {
+ TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const);
+ } else {
+ TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const);
+ }
+ TosaPadAttribute attribute(pad_const);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_PAD, Attribute_PadAttribute, &attribute,
@@ -1386,10 +1399,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
std::string shift_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(output);
- TosaRescaleAttribute attribute(input_zp, output_zp,
- /* multiplier = */ {}, /* shift = */ {},
- scale32, double_round, per_channel,
- input_unsigned, output_unsigned);
+ TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round,
+ per_channel, input_unsigned, output_unsigned);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESCALE, Attribute_RescaleAttribute, &attribute,
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 758e73e117c5cef17f8f0b1c543efc1df953b2f
+Subproject 0b6d7c271af1e6593e6a2cf14b32acea765f4b6