aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-13 19:19:53 +0000
committerTai Ly <tai.ly@arm.com>2024-04-15 14:28:55 +0000
commit6103155b3a2a555c3fc4a3a2173b35ea573c9600 (patch)
tree76c353422dfc1e7473172d9c9f00b1d5c636c2e8
parent5eddcd35c1776784baeeb39e92bad81da826e065 (diff)
downloadtosa_mlir_translator-6103155b3a2a555c3fc4a3a2173b35ea573c9600.tar.gz
[tosa_mlir_translator] Fix fp16, bf16 and fp8 serialization
Fix serialization and deserialization of fp16, bf16 and fp8 for pad_const, clamp min_val/max_val, and const values Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ia39a17d2f395584d5555d2c86cdae7113cf14e3f
-rw-r--r--src/TosaDeserialize.cpp277
-rw-r--r--src/TosaSerialize.cpp270
m---------third_party/serialization_lib0
3 files changed, 269 insertions, 278 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 6fa691e..fdbd892 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -166,58 +166,70 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
return mlir::success();
}
-mlir::DenseElementsAttr
-ConstructConstAttr(const mlir::RankedTensorType &output_type,
- TosaSerializationTensor *ts, const std::string &op_name) {
- const auto &data = ts->GetData();
- auto &shape = ts->GetShape();
- // compute output data size
- uint32_t out_size = 1;
- for (const auto dim : shape) {
- out_size *= dim;
- }
- mlir::DenseElementsAttr value_attr;
- switch (ts->GetDtype()) {
- case DType_FP32:
- case DType_BF16:
- case DType_FP8E4M3:
- case DType_FP8E5M2: {
- // for FP32, FP16 and FP8 types, value attributes are stored as FP32 values
+mlir::DenseElementsAttr GetConstAttr(const std::vector<uint8_t> &data,
+ const mlir::RankedTensorType &output_type,
+ uint32_t out_size) {
+ auto element_type = output_type.getElementType();
+ if (element_type.isF32()) {
+ // for FP32, value attributes are stored as FP32 values
std::vector<float> float_data;
TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data));
- break;
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(float_data));
+ }
+ if (element_type.isBF16()) {
+ mlir::SmallVector<mlir::APFloat> bf16_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ uint64_t byte0 = data[i * sizeof(int16_t)];
+ uint64_t byte1 = data[i * sizeof(int16_t) + 1];
+ uint64_t bits = byte0 + (byte1 << 8);
+ mlir::APInt bf16_bits(16, bits);
+ mlir::APFloat bf16(mlir::APFloat::BFloat(), bf16_bits);
+ bf16_data.push_back(bf16);
+ }
+ return mlir::DenseElementsAttr::get(output_type, bf16_data);
+ }
+ if (element_type.isFloat8E4M3FN()) {
+ mlir::SmallVector<mlir::APFloat> f8_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i]));
+ mlir::APFloat f8(mlir::APFloat::Float8E4M3FN(), f8_bits);
+ f8_data.push_back(f8);
+ }
+ return mlir::DenseElementsAttr::get(output_type, f8_data);
+ }
+ if (element_type.isFloat8E5M2()) {
+ mlir::SmallVector<mlir::APFloat> f8_data;
+ for (uint32_t i = 0; i < out_size; i++) {
+ mlir::APInt f8_bits(8, static_cast<uint64_t>(data[i]));
+ mlir::APFloat f8(mlir::APFloat::Float8E5M2(), f8_bits);
+ f8_data.push_back(f8);
+ }
+ return mlir::DenseElementsAttr::get(output_type, f8_data);
}
- case DType_INT4: {
+ if (element_type.isInteger(4)) {
std::vector<int8_t> int4_data;
TosaSerializationHandler::ConvertU8toI4(data, out_size, int4_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data));
- break;
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int4_data));
}
- case DType_INT8: {
+ if (element_type.isInteger(8)) {
std::vector<int8_t> int8_data;
TosaSerializationHandler::ConvertU8toI8(data, out_size, int8_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data));
- break;
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int8_data));
}
- case DType_INT16: {
+ if (element_type.isInteger(16)) {
std::vector<int16_t> int16_data;
TosaSerializationHandler::ConvertU8toI16(data, out_size, int16_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int16_data));
- break;
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(int16_data));
}
- case DType_INT32: {
+ if (element_type.isInteger(32)) {
std::vector<int32_t> int32_data;
TosaSerializationHandler::ConvertU8toI32(data, out_size, int32_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(int32_data));
- break;
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(int32_data));
}
- case DType_INT48: {
+ if (element_type.isInteger(48)) {
std::vector<int64_t> int48_data;
TosaSerializationHandler::ConvertU8toI48(data, out_size, int48_data);
std::vector<mlir::APInt> apint_data;
@@ -226,34 +238,38 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type,
/* isSigned = */ false);
apint_data.push_back(apint_value);
}
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(apint_data));
- break;
+ return mlir::DenseElementsAttr::get(output_type,
+ llvm::ArrayRef(apint_data));
}
- case DType_BOOL: {
+ if (element_type.isInteger(1)) {
std::vector<bool> bool_data;
TosaSerializationHandler::ConvertU8toBool(data, out_size, bool_data);
llvm::SmallVector<bool> bool_values(bool_data.begin(), bool_data.end());
- value_attr = mlir::DenseElementsAttr::get(output_type, bool_values);
- break;
+ return mlir::DenseElementsAttr::get(output_type, bool_values);
}
- case DType_FP16: {
+ if (element_type.isF16()) {
std::vector<half_float::half> half_data;
TosaSerializationHandler::ConvertU8toF16(data, out_size, half_data);
- value_attr =
- mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data));
- break;
+ return mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(half_data));
}
- case DType_UINT8:
- case DType_UINT16:
- default: {
+
+ return nullptr;
+}
+
+mlir::DenseElementsAttr
+ConstructConstAttr(const mlir::RankedTensorType &output_type,
+ TosaSerializationTensor *ts, const std::string &op_name) {
+ // compute output data size
+ uint32_t out_size = 1;
+ for (const auto dim : ts->GetShape()) {
+ out_size *= dim;
+ }
+ auto attr = GetConstAttr(ts->GetData(), output_type, out_size);
+ if (!attr) {
llvm::errs() << "ERROR: " << op_name
<< " contains unsupported element type\n";
- return nullptr;
}
- }
-
- return value_attr;
+ return attr;
}
mlir::LogicalResult ConstructVariableOps(mlir::ModuleOp &module) {
@@ -942,56 +958,31 @@ TosaMlirOperatorBuilder::build<Op_CLAMP>(TosaSerializationOperator *op) const {
TosaClampAttribute *attr =
static_cast<TosaClampAttribute *>(op->GetAttribute());
- mlir::Type input_element_type =
+ mlir::Type element_type =
llvm::cast<mlir::ShapedType>(input_val.getType()).getElementType();
- if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
- input_element_type)) {
- input_element_type = quantType.getStorageType();
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(element_type)) {
+ element_type = quantType.getStorageType();
}
+ auto element_const_type = mlir::RankedTensorType::get({1}, element_type);
+ auto min_values_attr = GetConstAttr(attr->min_val(), element_const_type, 1);
+ auto max_values_attr = GetConstAttr(attr->max_val(), element_const_type, 1);
+
mlir::Attribute min_val_attr, max_val_attr;
- if (input_element_type.isa<mlir::FloatType>()) {
- std::vector<float> min_float_data, max_float_data;
- TosaSerializationHandler::ConvertU8toF32(attr->min_val(), /* size = */ 1,
- min_float_data);
- TosaSerializationHandler::ConvertU8toF32(attr->max_val(), /* size = */ 1,
- max_float_data);
- min_val_attr =
- op_builder->getFloatAttr(input_element_type, min_float_data[0]);
- max_val_attr =
- op_builder->getFloatAttr(input_element_type, max_float_data[0]);
+ if (element_type.isa<mlir::FloatType>()) {
+ min_val_attr = op_builder->getFloatAttr(
+ element_type, min_values_attr.getValues<mlir::APFloat>()[0]);
+ max_val_attr = op_builder->getFloatAttr(
+ element_type, max_values_attr.getValues<mlir::APFloat>()[0]);
} else {
- std::vector<int32_t> min_int_data, max_int_data;
- TosaSerializationHandler::ConvertU8toI32(attr->min_val(), /* size = */ 1,
- min_int_data);
- TosaSerializationHandler::ConvertU8toI32(attr->max_val(), /* size = */ 1,
- max_int_data);
- if (input_element_type.isUnsignedInteger()) {
- if (input_element_type.isUnsignedInteger(8)) {
- uint8_t min_val = min_int_data[0];
- uint8_t max_val = max_int_data[0];
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else if (input_element_type.isUnsignedInteger(16)) {
- uint16_t min_val = min_int_data[0];
- uint16_t max_val = max_int_data[0];
- min_val_attr = op_builder->getIntegerAttr(input_element_type, min_val);
- max_val_attr = op_builder->getIntegerAttr(input_element_type, max_val);
- } else {
- llvm::errs()
- << "ERROR: " << get_string(op)
- << " contains unsupported unsigned int element data type.\n";
- return {};
- }
- } else {
- min_val_attr =
- op_builder->getIntegerAttr(input_element_type, min_int_data[0]);
- max_val_attr =
- op_builder->getIntegerAttr(input_element_type, max_int_data[0]);
- }
+ min_val_attr = op_builder->getIntegerAttr(
+ element_type, min_values_attr.getValues<mlir::APInt>()[0]);
+ max_val_attr = op_builder->getIntegerAttr(
+ element_type, max_values_attr.getValues<mlir::APInt>()[0]);
}
- mlir::Operation *mlir_op = op_builder->create<mlir::tosa::ClampOp>(
+ auto mlir_op = op_builder->create<mlir::tosa::ClampOp>(
loc, output_type, input_val, min_val_attr, max_val_attr);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
@@ -1101,80 +1092,36 @@ TosaMlirOperatorBuilder::build<Op_PAD>(TosaSerializationOperator *op) const {
assert(op->GetAttributeType() ==
Attribute_PadAttribute); // double check attribute type
TosaPadAttribute *attr = static_cast<TosaPadAttribute *>(op->GetAttribute());
-
- float pad_const_fp = 0.0f;
- int32_t pad_const_int = 0;
-
- if (element_type.isa<mlir::FloatType>()) {
- std::vector<float> float_data;
- TosaSerializationHandler::ConvertU8toF32(attr->pad_const(),
- /* size = */ 1, float_data);
- pad_const_fp = float_data[0];
- } else {
- std::vector<int32_t> int32_data;
- TosaSerializationHandler::ConvertU8toI32(attr->pad_const(),
- /* size = */ 1, int32_data);
- pad_const_int = int32_data[0];
+ const auto &pad_const_u8_data = attr->pad_const();
+
+ // check for any value in pad_const_u8_data
+ bool has_pad_const = false;
+ for (auto v : pad_const_u8_data) {
+ if (v != 0) {
+ has_pad_const = true;
+ break;
+ }
}
-
- // todo: int input_zp = attr->pad_input_zp();
-
- mlir::Operation *mlir_op;
- mlir::Value pad_const_value;
-
- bool isBoolType = element_type.isInteger(1);
- // First handle boolean type.
- if (isBoolType) {
- mlir::Type boolType = op_builder->getIntegerType(1);
- auto pad_const_type = mlir::RankedTensorType::get({}, boolType);
- // Treat zero integer is `false`, and any non-zero integner evaluates to
- // `true`.
- bool pad_const = pad_const_int == 0 ? false : true;
- auto pad_const_attr =
- mlir::DenseElementsAttr::get(pad_const_type, {pad_const});
- mlir::Operation *pad_const_op = op_builder->create<mlir::tosa::ConstOp>(
- loc, pad_const_type, pad_const_attr);
-
- block->push_back(pad_const_op);
- pad_const_value = pad_const_op->getResult(0);
- mlir_op = op_builder->create<mlir::tosa::PadOp>(
- loc, output_type, input_val, padding_val, pad_const_value);
-
+ if (!has_pad_const) {
+ // handle the cases where no explicit pad_const input.
+ auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
+ loc, output_type, input_val, padding_val);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
}
- // Second handle the cases where no explicit pad_const input.
- if (pad_const_int == 0 && pad_const_fp == 0.0f) {
- mlir_op = op_builder->create<mlir::tosa::PadOp>(loc, output_type, input_val,
- padding_val);
- block->push_back(mlir_op);
- return std::vector<mlir::Value>({mlir_op->getResult(0)});
- }
+ // has pad const - create a const op for pad_const input
+ auto pad_const_type = mlir::RankedTensorType::get({}, element_type);
+ auto pad_const_attr = GetConstAttr(pad_const_u8_data, pad_const_type, 1);
- // Then handle explicit numerical pad_const cases.
- if (pad_const_int != 0) {
- assert(pad_const_fp == 0.0f && llvm::isa<IntegerType>(element_type));
- auto pad_const_int_type = mlir::RankedTensorType::get({}, element_type);
- auto pad_const_int_attr =
- mlir::DenseElementsAttr::get(pad_const_int_type, {pad_const_int});
- mlir::Operation *pad_const_int_op = op_builder->create<mlir::tosa::ConstOp>(
- loc, pad_const_int_type, pad_const_int_attr);
- block->push_back(pad_const_int_op);
- pad_const_value = pad_const_int_op->getResult(0);
- } else {
- assert(pad_const_fp != 0 && llvm::isa<FloatType>(element_type));
- auto pad_const_fp_type = mlir::RankedTensorType::get({}, element_type);
- auto pad_const_fp_attr =
- mlir::DenseElementsAttr::get(pad_const_fp_type, {pad_const_fp});
- mlir::Operation *pad_const_fp_op = op_builder->create<mlir::tosa::ConstOp>(
- loc, pad_const_fp_type, pad_const_fp_attr);
- block->push_back(pad_const_fp_op);
- pad_const_value = pad_const_fp_op->getResult(0);
- }
-
- mlir_op = op_builder->create<mlir::tosa::PadOp>(loc, output_type, input_val,
- padding_val, pad_const_value);
+ auto pad_const_op = op_builder->create<mlir::tosa::ConstOp>(
+ loc, pad_const_type, pad_const_attr);
+
+ block->push_back(pad_const_op);
+ mlir::Value pad_const_value = pad_const_op->getResult(0);
+
+ auto mlir_op = op_builder->create<mlir::tosa::PadOp>(
+ loc, output_type, input_val, padding_val, pad_const_value);
block->push_back(mlir_op);
return std::vector<mlir::Value>({mlir_op->getResult(0)});
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 875303e..6553944 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -152,9 +152,29 @@ public:
TosaSerializationHandler *GetTsh() const;
TosaSerializationRegionBuilder *GetRegionBuilder() const;
mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op,
- mlir::Attribute &attr, DType dtype,
+ mlir::Attribute &attr,
+ mlir::Type element_type,
std::vector<uint8_t> &u8_data) const;
+ // populate u8_data with either int64_value or float_value depending on
+ // element_type
+ mlir::LogicalResult
+ GetU8DataFromIntOrFloatValue(int64_t int64_value, float fp_value,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with int_value depending on non-float element_type
+ mlir::LogicalResult
+ GetU8DataFromIntValues(const std::vector<int64_t> &int_values,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
+ // populate u8_data with fp_value depending on float element_type
+ mlir::LogicalResult
+ GetU8DataFromFloatValues(const std::vector<float> &fp_values,
+ mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const;
+
private:
std::string GetTensorName(mlir::Value val) const;
std::string GetVariableTensorName(mlir::Operation *op) const;
@@ -303,134 +323,146 @@ std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
}
mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
- mlir::Operation &op, mlir::Attribute &attr, DType type,
+ mlir::Operation &op, mlir::Attribute &attr, mlir::Type element_type,
std::vector<uint8_t> &u8_data) const {
+ if (!element_type.isIntOrFloat()) {
+ return mlir::failure();
+ }
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
- switch (type) {
- case DType_FP32:
- case DType_BF16:
- case DType_FP16:
- case DType_FP8E4M3:
- case DType_FP8E5M2: {
- std::vector<float> data;
+ // handle float types
+ if (element_type.isa<mlir::FloatType>()) {
+ std::vector<float> fp_data;
auto val_attr = attr.dyn_cast<mlir::FloatAttr>();
if (dense_attr) {
for (auto val : dense_attr.getValues<mlir::APFloat>()) {
- data.push_back(val.convertToFloat());
+ fp_data.push_back(val.convertToFloat());
}
} else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
+ fp_data.push_back((float)val_attr.getValueAsDouble());
} else {
op.emitOpError("Unknown const attribute");
return mlir::failure();
}
- if (type == DType_FP16) {
- TosaSerializationHandler::ConvertF16toU8(data, u8_data);
- } else {
- // for all other floating types, store as F32 values
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- }
- break;
+ return GetU8DataFromFloatValues(fp_data, element_type, u8_data);
}
- case DType_INT8: {
- std::vector<int8_t> data;
- auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int8_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return mlir::failure();
+ // element_type is integer type
+
+ bool isInt48 = element_type.isInteger(48);
+ std::vector<int64_t> i64_data;
+
+ auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
+ if (dense_attr) {
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ int64_t val = isInt48 ? static_cast<int64_t>(valueIt.getLimitedValue())
+ : valueIt.getSExtValue();
+ i64_data.push_back(val);
}
- TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- break;
+ } else if (val_attr) {
+ i64_data.push_back(val_attr.getInt());
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return mlir::failure();
}
- case DType_INT16: {
- std::vector<int16_t> data;
- auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int16_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return mlir::failure();
+ return GetU8DataFromIntValues(i64_data, element_type, u8_data);
+}
+
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromIntValues(
+ const std::vector<int64_t> &int64_values, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ switch (element_type.getIntOrFloatBitWidth()) {
+ case 1: {
+ // bool use bool vec
+ std::vector<bool> bool_values;
+ for (auto v : int64_values) {
+ bool bool_value = v == 0 ? false : true;
+ bool_values.push_back(bool_value);
}
- TosaSerializationHandler::ConvertI16toU8(data, u8_data);
+ TosaSerializationHandler::ConvertBooltoU8(bool_values, u8_data);
break;
}
- case DType_INT32: {
- std::vector<int32_t> data;
- auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int32_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
+ case 4:
+ case 8: {
+ // I4 and I8 use int8_t vec
+ std::vector<int8_t> i8_values;
+ for (auto v : int64_values) {
+ i8_values.push_back(static_cast<int8_t>(v));
+ }
+ if (element_type.isInteger(4)) {
+ TosaSerializationHandler::ConvertI4toU8(i8_values, u8_data);
} else {
- op.emitOpError("Unknown const attribute");
- return mlir::failure();
+ TosaSerializationHandler::ConvertI8toU8(i8_values, u8_data);
}
- TosaSerializationHandler::ConvertI32toU8(data, u8_data);
break;
}
- case DType_INT48: {
- std::vector<int64_t> data;
- auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
- uint64_t val = valueIt.getLimitedValue();
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return mlir::failure();
+ case 16: {
+ // I16 use int16_t vec
+ std::vector<int16_t> i16_values;
+ for (auto v : int64_values) {
+ i16_values.push_back(static_cast<int16_t>(v));
}
- TosaSerializationHandler::ConvertI48toU8(data, u8_data);
+ TosaSerializationHandler::ConvertI16toU8(i16_values, u8_data);
break;
}
- case DType_BOOL: {
- std::vector<bool> data;
- auto val_attr = attr.dyn_cast<mlir::BoolAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<bool>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getValue());
- } else {
- op.emitOpError("Unknown const attribute");
- return mlir::failure();
+ case 32: {
+ // I32 use int32_t vec
+ std::vector<int32_t> i32_values;
+ for (auto v : int64_values) {
+ i32_values.push_back(static_cast<int32_t>(v));
}
-
- TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
+ TosaSerializationHandler::ConvertI32toU8(i32_values, u8_data);
+ break;
+ }
+ case 48: {
+ // I48 use int64_t vec
+ TosaSerializationHandler::ConvertI48toU8(int64_values, u8_data);
break;
}
default: {
- op.emitOpError("Unknown element type of const attribute");
+ // unsupported bit widths
return mlir::failure();
}
}
+ return mlir::success();
+}
+mlir::LogicalResult TosaSerializationOperatorBuilder::GetU8DataFromFloatValues(
+ const std::vector<float> &fp_values, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ assert(
+ element_type
+ .isa<mlir::FloatType>()); // this should only be called for float type
+ if (element_type.isF16()) {
+ TosaSerializationHandler::ConvertF16toU8(fp_values, u8_data);
+ } else if (element_type.isBF16()) {
+ TosaSerializationHandler::ConvertBF16toU8(fp_values, u8_data);
+ } else if (element_type.isFloat8E4M3FN()) {
+ TosaSerializationHandler::ConvertFP8E4M3toU8(fp_values, u8_data);
+ } else if (element_type.isFloat8E5M2()) {
+ TosaSerializationHandler::ConvertFP8E5M2toU8(fp_values, u8_data);
+ } else if (element_type.isF32()) {
+ TosaSerializationHandler::ConvertF32toU8(fp_values, u8_data);
+ } else {
+ return mlir::failure();
+ }
return mlir::success();
}
+mlir::LogicalResult
+TosaSerializationOperatorBuilder::GetU8DataFromIntOrFloatValue(
+ int64_t int64_value, float fp_value, mlir::Type element_type,
+ std::vector<uint8_t> &u8_data) const {
+ if (element_type.isa<mlir::FloatType>()) {
+ return GetU8DataFromFloatValues({fp_value}, element_type, u8_data);
+ } else {
+ return GetU8DataFromIntValues({int64_value}, element_type, u8_data);
+ }
+}
+
// Main template to catch unimplemented translation.
template <typename T>
TosaSerializationOperator *
@@ -691,9 +723,12 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
}
std::vector<uint8_t> u8_data;
mlir::Attribute attr = op.getAttr(llvm::StringRef("value"));
- DType type = ts->GetDtype();
+ mlir::Type element_type =
+ llvm::cast<mlir::ShapedType>(op.getResult(0).getType()).getElementType();
- if (GetDataFromAttribute(op, attr, type, u8_data).failed()) {
+ if (GetDataFromAttribute(op, attr, element_type, u8_data).failed()) {
+ op.emitOpError("ERROR: GetDataFromAttribute() fails when building value of "
+ "const tensor");
return nullptr;
}
@@ -977,26 +1012,32 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ClampOp>(
}
std::vector<uint8_t> min_val, max_val;
+ float min_fp, max_fp;
+ int64_t min_int, max_int;
+
if (input_element_type.isa<mlir::FloatType>()) {
- auto min_fp =
+ min_fp =
mlir::cast<mlir::FloatAttr>(min_val_attr).getValue().convertToFloat();
- auto max_fp =
+ max_fp =
mlir::cast<mlir::FloatAttr>(max_val_attr).getValue().convertToFloat();
- TosaSerializationHandler::ConvertF32toU8({min_fp}, min_val);
- TosaSerializationHandler::ConvertF32toU8({max_fp}, max_val);
+ min_int = max_int = 0;
} else {
- int32_t min_int = 0;
- int32_t max_int = 0;
- if (input_element_type.isUnsignedInteger()) {
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getUInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getUInt();
- } else {
- assert(input_element_type.isa<mlir::IntegerType>());
- min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
- max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
- }
- TosaSerializationHandler::ConvertI32toU8({min_int}, min_val);
- TosaSerializationHandler::ConvertI32toU8({max_int}, max_val);
+ assert(input_element_type.isa<mlir::IntegerType>());
+ min_int = mlir::cast<mlir::IntegerAttr>(min_val_attr).getInt();
+ max_int = mlir::cast<mlir::IntegerAttr>(max_val_attr).getInt();
+ min_fp = max_fp = 0.f;
+ }
+
+ if (GetU8DataFromIntOrFloatValue(min_int, min_fp, input_element_type, min_val)
+ .failed()) {
+ op.emitOpError("Failed to serialize min value");
+ return nullptr;
+ }
+
+ if (GetU8DataFromIntOrFloatValue(max_int, max_fp, input_element_type, max_val)
+ .failed()) {
+ op.emitOpError("Failed to serialize max value");
+ return nullptr;
}
std::string input_name = GetTensorName(op.getOperand(0));
@@ -1150,11 +1191,14 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::PadOp>(
std::vector<uint8_t> pad_const;
mlir::Type input_element_type =
llvm::cast<mlir::ShapedType>(op.getOperand(0).getType()).getElementType();
- if (input_element_type.isa<mlir::FloatType>()) {
- TosaSerializationHandler::ConvertF32toU8({pad_const_fp}, pad_const);
- } else {
- TosaSerializationHandler::ConvertI32toU8({pad_const_int}, pad_const);
+
+ if (GetU8DataFromIntOrFloatValue(pad_const_int, pad_const_fp,
+ input_element_type, pad_const)
+ .failed()) {
+ op.emitOpError("Failed to serialize pad_const value");
+ return nullptr;
}
+
TosaPadAttribute attribute(pad_const);
TosaSerializationOperator *tyop = new TosaSerializationOperator(
@@ -1806,11 +1850,11 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
// zeros
mlir::Attribute initial_value = op->getAttr("initial_value");
std::vector<uint8_t> u8_data;
- DType element_type = Type2DType(tensor_type.getElementType());
if (initial_value) {
if (initial_value.isa<mlir::DenseElementsAttr>()) {
if (op_builder
- .GetDataFromAttribute(*op, initial_value, element_type, u8_data)
+ .GetDataFromAttribute(*op, initial_value,
+ tensor_type.getElementType(), u8_data)
.failed()) {
llvm::errs() << "ERROR: GetDataFromAttribute() fails when building "
"initial_value of variable tensor\n";
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject ad78daaf0fa1e41742cbed314459c3dbbb483c2
+Subproject 57d781883142db8a45fe98ac1a1dfacc49cba78