From 86db8bc37237c68a30a917ff77cbcd7784879ae4 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 6 Feb 2024 21:32:52 +0000 Subject: [tosa_mlir_translator] Add FP8 support Add serialization and deserialization support for FP8 data types. Also, added deserialization support for BF16 constants. BF16 and FP8 constants are serialized and deserialized as F32 values. Signed-off-by: Tai Ly Change-Id: I919acd82dc5e0b85024b6403d9623eaa26151aef --- src/TosaSerialize.cpp | 252 +++++++++++--------------------------------------- 1 file changed, 55 insertions(+), 197 deletions(-) (limited to 'src/TosaSerialize.cpp') diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 04709b7..fc6655b 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -77,6 +77,10 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { static DType Type2DType(mlir::Type element_type) { if (element_type.isF64() || element_type.isF32()) { return DType_FP32; + } else if (element_type.isFloat8E5M2()) { + return DType_FP8E5M2; + } else if (element_type.isFloat8E4M3FN()) { + return DType_FP8E4M3; } else if (element_type.isF16()) { return DType_FP16; } else if (element_type.isBF16()) { @@ -101,29 +105,6 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -// Returns number of bits TOSA flatbuffer store in tensor raw data array -uint64_t GetDTypeSize(DType dtype) { - switch (dtype) { - case DType_INT4: - return 4; - case DType_BOOL: - case DType_UINT8: - case DType_INT8: - return 8; - case DType_INT16: - return 16; - case DType_FP32: - case DType_INT32: - return 32; - case DType_INT48: - return 48; - default: - llvm::errs() << "WARNING: unsupported dtype " << EnumNamesDType()[dtype] - << "\n"; - return 1; - } -} - static DType Type2PoolAccumDType(mlir::Type element_type) { // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; if (element_type.isF32()) { @@ -168,7 +149,8 @@ public: TosaSerializationOperator *build(mlir::Operation &op) const; TosaSerializationHandler *GetTsh() const; TosaSerializationRegionBuilder *GetRegionBuilder() const; - mlir::LogicalResult GetDataFromAttribute(mlir::Attribute &attr, DType dtype, + mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op, + mlir::Attribute &attr, DType dtype, std::vector &u8_data) const; private: @@ -319,24 +301,39 @@ std::string TosaSerializationOperatorBuilder::GetVariableTensorName( } mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( - mlir::Attribute &attr, DType type, std::vector &u8_data) const { + mlir::Operation &op, mlir::Attribute &attr, DType type, + std::vector &u8_data) const { auto dense_attr = attr.dyn_cast(); - if (type == DType_FP32) { + + switch (type) { + case DType_FP32: + case DType_BF16: + case DType_FP16: + case DType_FP8E4M3: + case DType_FP8E5M2: { std::vector data; auto val_attr = attr.dyn_cast(); if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); + for (auto val : dense_attr.getValues()) { + data.push_back(val.convertToFloat()); } } else if (val_attr) { data.push_back((float)val_attr.getValueAsDouble()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } else if (type == DType_INT8) { + + if (type == DType_FP16) { + TosaSerializationHandler::ConvertF16toU8(data, u8_data); + } else { + // for all other floating types, store as F32 values + TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } + break; + } + case DType_INT8: { std::vector data; auto val_attr = attr.dyn_cast(); @@ -347,11 +344,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI8toU8(data, u8_data); - } else if (type == DType_INT16) { + break; + } + case DType_INT16: { std::vector data; auto val_attr = attr.dyn_cast(); @@ -362,11 +361,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI16toU8(data, u8_data); - } else if (type == DType_INT32) { + break; + } + case DType_INT32: { std::vector data; auto val_attr = attr.dyn_cast(); @@ -377,28 +378,32 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI32toU8(data, u8_data); - } else if (type == DType_INT48) { + break; + } + case DType_INT48: { std::vector data; auto val_attr = attr.dyn_cast(); if (dense_attr) { - for (auto val : dense_attr.getValues()) { + for (auto valueIt : dense_attr.getValues()) { + uint64_t val = valueIt.getLimitedValue(); data.push_back(val); } } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI48toU8(data, u8_data); - } else if (type == DType_BOOL) { + break; + } + case DType_BOOL: { std::vector data; - auto val_attr = attr.dyn_cast(); if (dense_attr) { @@ -408,15 +413,18 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getValue()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertBooltoU8(data, u8_data); - } else { - llvm::errs() << "Unknown element type of const attribute\n"; + break; + } + default: { + op.emitOpError("Unknown element type of const attribute"); return mlir::failure(); } + } return mlir::success(); } @@ -669,19 +677,6 @@ TosaSerializationOperatorBuilder::build( return nullptr; } -#if 0 - // Gracefully handle constants of "constant unit" type which have no value - // by creating a numpy value of 0. - auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast(); - - if (unit_val) - { - std::vector data = { 0.0 }; - type = DType_FP32; - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } -#endif - // Update tensor.data array with Const value attribute mlir::Attribute value_attr = op.getAttr("value"); if (!value_attr) { @@ -689,139 +684,10 @@ TosaSerializationOperatorBuilder::build( return nullptr; } std::vector u8_data; - + mlir::Attribute attr = op.getAttr(llvm::StringRef("value")); DType type = ts->GetDtype(); - if (type == DType_FP32) { - std::vector data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast(); - - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val.convertToFloat()); - } - } else if (val_attr) { - data.push_back((float)val_attr.getValueAsDouble()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } else if (type == DType_FP16) { - std::vector data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast(); - - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val.convertToFloat()); - } - } else if (val_attr) { - data.push_back((float)val_attr.getValueAsDouble()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertF16toU8(data, u8_data); - } else if (type == DType_INT8) { - std::vector data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast(); - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI8toU8(data, u8_data); - } else if (type == DType_INT16) { - std::vector data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast(); - - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI16toU8(data, u8_data); - } else if (type == DType_INT32) { - std::vector data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast(); - - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI32toU8(data, u8_data); - } else if (type == DType_INT48) { - std::vector data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast(); - - if (dense_attr) { - for (auto valueIt : dense_attr.getValues()) { - uint64_t val = valueIt.getLimitedValue(); - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI48toU8(data, u8_data); - } else if (type == DType_BOOL) { - std::vector data; - - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast(); - - if (dense_attr) { - for (auto val : dense_attr.getValues()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getValue()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - - TosaSerializationHandler::ConvertBooltoU8(data, u8_data); - } else { - op.emitOpError("Unknown element type of const attribute"); + if (GetDataFromAttribute(op, attr, type, u8_data).failed()) { return nullptr; } @@ -1898,7 +1764,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( if (initial_value) { if (initial_value.isa()) { if (op_builder - .GetDataFromAttribute(initial_value, element_type, u8_data) + .GetDataFromAttribute(*op, initial_value, element_type, u8_data) .failed()) { llvm::errs() << "ERROR: GetDataFromAttribute() fails when building " "initial_value of variable tensor\n"; @@ -1909,14 +1775,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( return mlir::failure(); } } else { - uint64_t num_elements = 1; - for (int64_t dim : tensor_type.getShape()) { - num_elements *= dim; - } - uint64_t num_bits = num_elements * GetDTypeSize(element_type); - uint64_t num_bytes = - (num_bits % 8 == 0) ? (num_bits / 8) : (num_bits / 8) + 1; - // std::fill_n(u8_data.begin(), num_bytes, 0); TosaSerializationHandler::ForceAlignTensorData(u8_data); } ser_tensor->SetData(u8_data); -- cgit v1.2.1