aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-08 18:32:46 +0000
committerTai Ly <tai.ly@arm.com>2024-03-17 19:55:42 -0700
commit909d4d159ee12c6bc8113974d76f46249b6fd7fb (patch)
tree7e6320d8f74ba6478a654404ccc74cca2ff3219f /src
parentf983e51df5030facfd1c5bf59dcc67a32a1913a8 (diff)
downloadtosa_mlir_translator-909d4d159ee12c6bc8113974d76f46249b6fd7fb.tar.gz
[tosa_mlir_translator] Use new Clamp and Pad fbs attributes
This implements changes required for new Tosa Flatbuffer schema where Clamp and Pad attributes have changed to use ubyte arrays to store int or float values. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I2aa2025422fda4aacaf6d80727060a01a30cee89
Diffstat (limited to 'src')
-rw-r--r--src/TosaDeserialize.cpp78
-rw-r--r--src/TosaSerialize.cpp47
2 files changed, 82 insertions, 43 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index d69f005..b12d22a 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -946,29 +946,44 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
mlir::Attribute min_val_attr, max_val_attr;
if (input_element_type.isa<mlir::FloatType>()) {
- min_val_attr = op_builder->getFloatAttr(input_element_type, attr->min_fp());
- max_val_attr = op_builder->getFloatAttr(input_element_type, attr->max_fp());
- } else if (input_element_type.isUnsignedInteger()) {
- if (input_element_type.isUnsignedInteger(8)) {
- uint8_t min_val = attr->min_int();
- uint8_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else if (input_element_type.isUnsignedInteger(16)) {
- uint16_t min_val = attr->min_int();
- uint16_t max_val = attr->max_int();
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else {
- llvm::errs() << "ERROR: " << get_string(op)
- << " contains unsupported unsigned int element data type.\n";
- return {};
- }
- } else {
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1,
+ min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1,
+ max_float_data);
min_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->min_int());
+ op_builder->getFloatAttr(input_element_type, min_float_data[0]);
max_val_attr =
- op_builder->getIntegerAttr(input_element_type, attr->max_int());
+ op_builder->getFloatAttr(input_element_type, max_float_data[0]);
+ } else {
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1,
+ min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1,
+ max_int_data);
+ if (input_element_type.isUnsignedInteger()) {
+ if (input_element_type.isUnsignedInteger(8)) {
+ uint8_t min_val = min_int_data[0];
+ uint8_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else if (input_element_type.isUnsignedInteger(16)) {
+ uint16_t min_val = min_int_data[0];
+ uint16_t max_val = max_int_data[0];
+ min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
+ max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
+ } else {
+ llvm::errs()
+ << "ERROR: " << get_string(op)
+ << " contains unsupported unsigned int element data type.\n";
+ return {};
+ }
+ } else {
+ min_val_attr =
+ op_builder->getIntegerAttr(input_element_type, min_int_data[0]);
+ max_val_attr =
+ op_builder->getIntegerAttr(input_element_type, max_int_data[0]);
+ }
}
mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
@@ -1075,19 +1090,32 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
tensor_type_map->at(op->GetInputTensorNames()[0]);
mlir::RankedTensorType output_type =
tensor_type_map->at(op->GetOutputTensorNames()[0]);
+ const auto element_type =
+ input_val.getType().cast<mlir::ShapedType>().getElementType();
assert(op->GetAttributeType() ==
Attribute_PadAttribute); // double check attribute type
TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
- auto pad_const_int = attr->pad_const_int();
- auto pad_const_fp = attr->pad_const_fp();
+ float pad_const_fp = 0.0f;
+ int32_t pad_const_int = 0;
+
+ if (element_type.isa<mlir::FloatType>()) {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(attr->pad_const(),
+ /* size = */ 1, float_data);
+ pad_const_fp = float_data[0];
+ } else {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(attr->pad_const(),
+ /* size = */ 1, int32_data);
+ pad_const_int = int32_data[0];
+ }
+
// todo: int input_zp = attr->pad_input_zp();
mlir::Operation *mlir_op;
mlir::Value pad_const_value;
- const auto element_type =
- input_val.getType().cast<mlir::ShapedType>().getElementType();
bool isBoolType = element_type.isInteger(1);
// First handle boolean type.
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 54a1d28..55a11fd 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -954,10 +954,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
mlir::Operation &op) const {
auto min_val_attr = op.getAttr("min_val");
auto max_val_attr = op.getAttr("max_val");
- float min_fp = 0;
- float max_fp = 0;
- int32_t min_int = 0;
- int32_t max_int = 0;
mlir::Type input_element_type =
llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
@@ -966,24 +962,33 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
input_element_type = quantType.getStorageType();
}
+ std::vector<uint8_t> min_val, max_val;
if (input_element_type.isa<mlir::FloatType>()) {
- min_fp =
+ auto min_fp =
mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
- max_fp =
+ auto max_fp =
mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
- } else if (input_element_type.isUnsignedInteger()) {
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val);
+ TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val);
} else {
- assert(input_element_type.isa<mlir::IntegerType>());
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ int32_t min_int = 0;
+ int32_t max_int = 0;
+ if (input_element_type.isUnsignedInteger()) {
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
+ } else {
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ }
+ TosaSerializationHandler::ConvertI32toU8({min_int}, min_val);
+ TosaSerializationHandler::ConvertI32toU8({max_int}, max_val);
}
std::string input_name = GetTensorName(op.getOperand(0));
std::string output_name = GetTensorName(op.getResult(0));
- TosaClampAttribute attribute(min_int, max_int, min_fp, max_fp);
+ TosaClampAttribute attribute(min_val, max_val);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_CLAMP, Attribute_ClampAttribute, &attribute,
@@ -1128,7 +1133,15 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
}
}
- TosaPadAttribute attribute({}, pad_const_int, pad_const_fp);
+ std::vector<uint8_t> pad_const;
+ mlir::Type input_element_type =
+ llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
+ if (input_element_type.isa<mlir::FloatType>()) {
+ TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const);
+ } else {
+ TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const);
+ }
+ TosaPadAttribute attribute(pad_const);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_PAD, Attribute_PadAttribute, &attribute,
@@ -1386,10 +1399,8 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::RescaleOp>(
std::string shift_name = GetTensorName(op.getOperand(2));
std::string output_name = GetTensorName(output);
- TosaRescaleAttribute attribute(input_zp, output_zp,
- /* multiplier = */ {}, /* shift = */ {},
- scale32, double_round, per_channel,
- input_unsigned, output_unsigned);
+ TosaRescaleAttribute attribute(input_zp, output_zp, scale32, double_round,
+ per_channel, input_unsigned, output_unsigned);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
Op_RESCALE, Attribute_RescaleAttribute, &attribute,