aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-02-06 21:32:52 +0000
committerTai Ly <tai.ly@arm.com>2024-02-21 18:54:47 +0000
commit86db8bc37237c68a30a917ff77cbcd7784879ae4 (patch)
tree23e3417da3c92fb51cc9812d1468cd044ff74e2f
parent05a243c3b8d01fe57aca9a5b9c1b835ee5d1e6b2 (diff)
downloadtosa_mlir_translator-86db8bc37237c68a30a917ff77cbcd7784879ae4.tar.gz
[tosa_mlir_translator] Add FP8 support
Add serialization and deserialization support for FP8 data types. Also add deserialization support for BF16 constants. BF16 and FP8 constants are serialized and deserialized as F32 values. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I919acd82dc5e0b85024b6403d9623eaa26151aef
-rw-r--r--src/TosaDeserialize.cpp13
-rw-r--r--src/TosaSerialize.cpp252
m---------third_party/serialization_lib0
3 files changed, 66 insertions, 199 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 3c7db8e..87c363f 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -138,6 +138,12 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
case DType_BF16:
element_type = op_builder->getBF16Type();
break;
+ case DType_FP8E4M3:
+ element_type = op_builder->getFloat8E4M3FNType();
+ break;
+ case DType_FP8E5M2:
+ element_type = op_builder->getFloat8E5M2Type();
+ break;
case DType_SHAPE:
llvm::errs()
<< "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n";
@@ -172,7 +178,11 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type,
}
mlir::DenseElementsAttr value_attr;
switch (ts->GetDtype()) {
- case DType_FP32: {
+ case DType_FP32:
+ case DType_BF16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2: {
+ // for FP32, BF16 and FP8 types, value attributes are stored as FP32 values
std::vector<float> float_data;
TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
value_attr =
@@ -236,7 +246,6 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type,
}
case DType_UINT8:
case DType_UINT16:
- case DType_BF16:
default: {
llvm::errs() << "ERROR: " << op_name
<< " contains unsupported element type\n";
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 04709b7..fc6655b 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -77,6 +77,10 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
static DType Type2DType(mlir::Type element_type) {
if (element_type.isF64() || element_type.isF32()) {
return DType_FP32;
+ } else if (element_type.isFloat8E5M2()) {
+ return DType_FP8E5M2;
+ } else if (element_type.isFloat8E4M3FN()) {
+ return DType_FP8E4M3;
} else if (element_type.isF16()) {
return DType_FP16;
} else if (element_type.isBF16()) {
@@ -101,29 +105,6 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-// Returns number of bits TOSA flatbuffer store in tensor raw data array
-uint64_t GetDTypeSize(DType dtype) {
- switch (dtype) {
- case DType_INT4:
- return 4;
- case DType_BOOL:
- case DType_UINT8:
- case DType_INT8:
- return 8;
- case DType_INT16:
- return 16;
- case DType_FP32:
- case DType_INT32:
- return 32;
- case DType_INT48:
- return 48;
- default:
- llvm::errs() << "WARNING: unsupported dtype " << EnumNamesDType()[dtype]
- << "\n";
- return 1;
- }
-}
-
static DType Type2PoolAccumDType(mlir::Type element_type) {
// def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
if (element_type.isF32()) {
@@ -168,7 +149,8 @@ public:
TosaSerializationOperator *build(mlir::Operation &op) const;
TosaSerializationHandler *GetTsh() const;
TosaSerializationRegionBuilder *GetRegionBuilder() const;
- mlir::LogicalResult GetDataFromAttribute(mlir::Attribute &attr, DType dtype,
+ mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op,
+ mlir::Attribute &attr, DType dtype,
std::vector<uint8_t> &u8_data) const;
private:
@@ -319,24 +301,39 @@ std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
}
mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
- mlir::Attribute &attr, DType type, std::vector<uint8_t> &u8_data) const {
+ mlir::Operation &op, mlir::Attribute &attr, DType type,
+ std::vector<uint8_t> &u8_data) const {
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
- if (type == DType_FP32) {
+
+ switch (type) {
+ case DType_FP32:
+ case DType_BF16:
+ case DType_FP16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2: {
std::vector<float> data;
auto val_attr = attr.dyn_cast<mlir::FloatAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<float>()) {
- data.push_back(val);
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ data.push_back(val.convertToFloat());
}
} else if (val_attr) {
data.push_back((float)val_attr.getValueAsDouble());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_INT8) {
+
+ if (type == DType_FP16) {
+ TosaSerializationHandler::ConvertF16toU8(data, u8_data);
+ } else {
+ // for all other floating types, store as F32 values
+ TosaSerializationHandler::ConvertF32toU8(data, u8_data);
+ }
+ break;
+ }
+ case DType_INT8: {
std::vector<int8_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -347,11 +344,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
+ break;
+ }
+ case DType_INT16: {
std::vector<int16_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -362,11 +361,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
+ break;
+ }
+ case DType_INT32: {
std::vector<int32_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -377,28 +378,32 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
+ break;
+ }
+ case DType_INT48: {
std::vector<int64_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<int64_t>()) {
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ uint64_t val = valueIt.getLimitedValue();
data.push_back(val);
}
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
+ break;
+ }
+ case DType_BOOL: {
std::vector<bool> data;
-
auto val_attr = attr.dyn_cast<mlir::BoolAttr>();
if (dense_attr) {
@@ -408,15 +413,18 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getValue());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- llvm::errs() << "Unknown element type of const attribute\n";
+ break;
+ }
+ default: {
+ op.emitOpError("Unknown element type of const attribute");
return mlir::failure();
}
+ }
return mlir::success();
}
@@ -669,19 +677,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
-#if 0
- // Gracefully handle constants of "constant unit" type which have no value
- // by creating a numpy value of 0.
- auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::UnitAttr>();
-
- if (unit_val)
- {
- std::vector<float> data = { 0.0 };
- type = DType_FP32;
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- }
-#endif
-
// Update tensor.data array with Const value attribute
mlir::Attribute value_attr = op.getAttr("value");
if (!value_attr) {
@@ -689,139 +684,10 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
std::vector<uint8_t> u8_data;
-
+ mlir::Attribute attr = op.getAttr(llvm::StringRef("value"));
DType type = ts->GetDtype();
- if (type == DType_FP32) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<mlir::APFloat>()) {
- data.push_back(val.convertToFloat());
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_FP16) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<mlir::APFloat>()) {
- data.push_back(val.convertToFloat());
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF16toU8(data, u8_data);
- } else if (type == DType_INT8) {
- std::vector<int8_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int8_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
- std::vector<int16_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int16_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
- std::vector<int32_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int32_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
- std::vector<int64_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
- uint64_t val = valueIt.getLimitedValue();
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
- std::vector<bool> data;
-
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::BoolAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<bool>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getValue());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
-
- TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- op.emitOpError("Unknown element type of const attribute");
+ if (GetDataFromAttribute(op, attr, type, u8_data).failed()) {
return nullptr;
}
@@ -1898,7 +1764,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
if (initial_value) {
if (initial_value.isa<mlir::DenseElementsAttr>()) {
if (op_builder
- .GetDataFromAttribute(initial_value, element_type, u8_data)
+ .GetDataFromAttribute(*op, initial_value, element_type, u8_data)
.failed()) {
llvm::errs() << "ERROR: GetDataFromAttribute() fails when building "
"initial_value of variable tensor\n";
@@ -1909,14 +1775,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
return mlir::failure();
}
} else {
- uint64_t num_elements = 1;
- for (int64_t dim : tensor_type.getShape()) {
- num_elements *= dim;
- }
- uint64_t num_bits = num_elements * GetDTypeSize(element_type);
- uint64_t num_bytes =
- (num_bits % 8 == 0) ? (num_bits / 8) : (num_bits / 8) + 1;
- // std::fill_n(u8_data.begin(), num_bytes, 0);
TosaSerializationHandler::ForceAlignTensorData(u8_data);
}
ser_tensor->SetData(u8_data);
diff --git a/third_party/serialization_lib b/third_party/serialization_lib
-Subproject 8137a4369acefa4c01f08db95a2b1b290e8dd70
+Subproject a029f1f02707f40f6990df53fd4f56684490d58