aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
author: Tai Ly <tai.ly@arm.com> 2024-02-06 21:32:52 +0000
committer: Tai Ly <tai.ly@arm.com> 2024-02-21 18:54:47 +0000
commit: 86db8bc37237c68a30a917ff77cbcd7784879ae4 (patch)
tree: 23e3417da3c92fb51cc9812d1468cd044ff74e2f /src/TosaSerialize.cpp
parent: 05a243c3b8d01fe57aca9a5b9c1b835ee5d1e6b2 (diff)
download: tosa_mlir_translator-86db8bc37237c68a30a917ff77cbcd7784879ae4.tar.gz
[tosa_mlir_translator] Add FP8 support
Add serialization and deserialization support for FP8 data types. Also add deserialization support for BF16 constants. BF16 and FP8 constants are serialized and deserialized as F32 values.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I919acd82dc5e0b85024b6403d9623eaa26151aef
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- src/TosaSerialize.cpp 252
1 files changed, 55 insertions, 197 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 04709b7..fc6655b 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -77,6 +77,10 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
static DType Type2DType(mlir::Type element_type) {
if (element_type.isF64() || element_type.isF32()) {
return DType_FP32;
+ } else if (element_type.isFloat8E5M2()) {
+ return DType_FP8E5M2;
+ } else if (element_type.isFloat8E4M3FN()) {
+ return DType_FP8E4M3;
} else if (element_type.isF16()) {
return DType_FP16;
} else if (element_type.isBF16()) {
@@ -101,29 +105,6 @@ static DType Type2DType(mlir::Type element_type) {
return DType_UNKNOWN;
}
-// Returns number of bits TOSA flatbuffer store in tensor raw data array
-uint64_t GetDTypeSize(DType dtype) {
- switch (dtype) {
- case DType_INT4:
- return 4;
- case DType_BOOL:
- case DType_UINT8:
- case DType_INT8:
- return 8;
- case DType_INT16:
- return 16;
- case DType_FP32:
- case DType_INT32:
- return 32;
- case DType_INT48:
- return 48;
- default:
- llvm::errs() << "WARNING: unsupported dtype " << EnumNamesDType()[dtype]
- << "\n";
- return 1;
- }
-}
-
static DType Type2PoolAccumDType(mlir::Type element_type) {
// def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
if (element_type.isF32()) {
@@ -168,7 +149,8 @@ public:
TosaSerializationOperator *build(mlir::Operation &op) const;
TosaSerializationHandler *GetTsh() const;
TosaSerializationRegionBuilder *GetRegionBuilder() const;
- mlir::LogicalResult GetDataFromAttribute(mlir::Attribute &attr, DType dtype,
+ mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op,
+ mlir::Attribute &attr, DType dtype,
std::vector<uint8_t> &u8_data) const;
private:
@@ -319,24 +301,39 @@ std::string TosaSerializationOperatorBuilder::GetVariableTensorName(
}
mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
- mlir::Attribute &attr, DType type, std::vector<uint8_t> &u8_data) const {
+ mlir::Operation &op, mlir::Attribute &attr, DType type,
+ std::vector<uint8_t> &u8_data) const {
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
- if (type == DType_FP32) {
+
+ switch (type) {
+ case DType_FP32:
+ case DType_BF16:
+ case DType_FP16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2: {
std::vector<float> data;
auto val_attr = attr.dyn_cast<mlir::FloatAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<float>()) {
- data.push_back(val);
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ data.push_back(val.convertToFloat());
}
} else if (val_attr) {
data.push_back((float)val_attr.getValueAsDouble());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_INT8) {
+
+ if (type == DType_FP16) {
+ TosaSerializationHandler::ConvertF16toU8(data, u8_data);
+ } else {
+ // for all other floating types, store as F32 values
+ TosaSerializationHandler::ConvertF32toU8(data, u8_data);
+ }
+ break;
+ }
+ case DType_INT8: {
std::vector<int8_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -347,11 +344,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
+ break;
+ }
+ case DType_INT16: {
std::vector<int16_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -362,11 +361,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
+ break;
+ }
+ case DType_INT32: {
std::vector<int32_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
@@ -377,28 +378,32 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
+ break;
+ }
+ case DType_INT48: {
std::vector<int64_t> data;
auto val_attr = attr.dyn_cast<mlir::IntegerAttr>();
if (dense_attr) {
- for (auto val : dense_attr.getValues<int64_t>()) {
+ for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
+ uint64_t val = valueIt.getLimitedValue();
data.push_back(val);
}
} else if (val_attr) {
data.push_back(val_attr.getInt());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
+ break;
+ }
+ case DType_BOOL: {
std::vector<bool> data;
-
auto val_attr = attr.dyn_cast<mlir::BoolAttr>();
if (dense_attr) {
@@ -408,15 +413,18 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute(
} else if (val_attr) {
data.push_back(val_attr.getValue());
} else {
- llvm::errs() << "Unknown const attribute\n";
+ op.emitOpError("Unknown const attribute");
return mlir::failure();
}
TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- llvm::errs() << "Unknown element type of const attribute\n";
+ break;
+ }
+ default: {
+ op.emitOpError("Unknown element type of const attribute");
return mlir::failure();
}
+ }
return mlir::success();
}
@@ -669,19 +677,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
-#if 0
- // Gracefully handle constants of "constant unit" type which have no value
- // by creating a numpy value of 0.
- auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::UnitAttr>();
-
- if (unit_val)
- {
- std::vector<float> data = { 0.0 };
- type = DType_FP32;
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- }
-#endif
-
// Update tensor.data array with Const value attribute
mlir::Attribute value_attr = op.getAttr("value");
if (!value_attr) {
@@ -689,139 +684,10 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
std::vector<uint8_t> u8_data;
-
+ mlir::Attribute attr = op.getAttr(llvm::StringRef("value"));
DType type = ts->GetDtype();
- if (type == DType_FP32) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<mlir::APFloat>()) {
- data.push_back(val.convertToFloat());
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF32toU8(data, u8_data);
- } else if (type == DType_FP16) {
- std::vector<float> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<mlir::APFloat>()) {
- data.push_back(val.convertToFloat());
- }
- } else if (val_attr) {
- data.push_back((float)val_attr.getValueAsDouble());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertF16toU8(data, u8_data);
- } else if (type == DType_INT8) {
- std::vector<int8_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int8_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI8toU8(data, u8_data);
- } else if (type == DType_INT16) {
- std::vector<int16_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int16_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI16toU8(data, u8_data);
- } else if (type == DType_INT32) {
- std::vector<int32_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<int32_t>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI32toU8(data, u8_data);
- } else if (type == DType_INT48) {
- std::vector<int64_t> data;
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>();
-
- if (dense_attr) {
- for (auto valueIt : dense_attr.getValues<mlir::APInt>()) {
- uint64_t val = valueIt.getLimitedValue();
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getInt());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
- TosaSerializationHandler::ConvertI48toU8(data, u8_data);
- } else if (type == DType_BOOL) {
- std::vector<bool> data;
-
- auto dense_attr = op.getAttr(llvm::StringRef("value"))
- .dyn_cast<mlir::DenseElementsAttr>();
- auto val_attr =
- op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::BoolAttr>();
-
- if (dense_attr) {
- for (auto val : dense_attr.getValues<bool>()) {
- data.push_back(val);
- }
- } else if (val_attr) {
- data.push_back(val_attr.getValue());
- } else {
- op.emitOpError("Unknown const attribute");
- return nullptr;
- }
-
- TosaSerializationHandler::ConvertBooltoU8(data, u8_data);
- } else {
- op.emitOpError("Unknown element type of const attribute");
+ if (GetDataFromAttribute(op, attr, type, u8_data).failed()) {
return nullptr;
}
@@ -1898,7 +1764,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
if (initial_value) {
if (initial_value.isa<mlir::DenseElementsAttr>()) {
if (op_builder
- .GetDataFromAttribute(initial_value, element_type, u8_data)
+ .GetDataFromAttribute(*op, initial_value, element_type, u8_data)
.failed()) {
llvm::errs() << "ERROR: GetDataFromAttribute() fails when building "
"initial_value of variable tensor\n";
@@ -1909,14 +1775,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock(
return mlir::failure();
}
} else {
- uint64_t num_elements = 1;
- for (int64_t dim : tensor_type.getShape()) {
- num_elements *= dim;
- }
- uint64_t num_bits = num_elements * GetDTypeSize(element_type);
- uint64_t num_bytes =
- (num_bits % 8 == 0) ? (num_bits / 8) : (num_bits / 8) + 1;
- // std::fill_n(u8_data.begin(), num_bytes, 0);
TosaSerializationHandler::ForceAlignTensorData(u8_data);
}
ser_tensor->SetData(u8_data);