diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/TosaDeserialize.cpp | 13 | ||||
-rw-r--r-- | src/TosaSerialize.cpp | 252 |
2 files changed, 66 insertions, 199 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 3c7db8e..87c363f 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -138,6 +138,12 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, case DType_BF16: element_type = op_builder->getBF16Type(); break; + case DType_FP8E4M3: + element_type = op_builder->getFloat8E4M3FNType(); + break; + case DType_FP8E5M2: + element_type = op_builder->getFloat8E5M2Type(); + break; case DType_SHAPE: llvm::errs() << "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n"; @@ -172,7 +178,11 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type, } mlir::DenseElementsAttr value_attr; switch (ts->GetDtype()) { - case DType_FP32: { + case DType_FP32: + case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: { + // for FP32, FP16 and FP8 types, value attributes are stored as FP32 values std::vector<float> float_data; TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); value_attr = @@ -236,7 +246,6 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type, } case DType_UINT8: case DType_UINT16: - case DType_BF16: default: { llvm::errs() << "ERROR: " << op_name << " contains unsupported element type\n"; diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 04709b7..fc6655b 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -77,6 +77,10 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { static DType Type2DType(mlir::Type element_type) { if (element_type.isF64() || element_type.isF32()) { return DType_FP32; + } else if (element_type.isFloat8E5M2()) { + return DType_FP8E5M2; + } else if (element_type.isFloat8E4M3FN()) { + return DType_FP8E4M3; } else if (element_type.isF16()) { return DType_FP16; } else if (element_type.isBF16()) { @@ -101,29 +105,6 @@ static DType Type2DType(mlir::Type element_type) { return DType_UNKNOWN; } -// Returns number of bits TOSA flatbuffer store in tensor raw data array -uint64_t GetDTypeSize(DType dtype) { - switch (dtype) { - case DType_INT4: - return 4; - case DType_BOOL: - case DType_UINT8: - case DType_INT8: - return 8; - case DType_INT16: - return 16; - case DType_FP32: - case DType_INT32: - return 32; - case DType_INT48: - return 48; - default: - llvm::errs() << "WARNING: unsupported dtype " << EnumNamesDType()[dtype] - << "\n"; - return 1; - } -} - static DType Type2PoolAccumDType(mlir::Type element_type) { // def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; if (element_type.isF32()) { @@ -168,7 +149,8 @@ public: TosaSerializationOperator *build(mlir::Operation &op) const; TosaSerializationHandler *GetTsh() const; TosaSerializationRegionBuilder *GetRegionBuilder() const; - mlir::LogicalResult GetDataFromAttribute(mlir::Attribute &attr, DType dtype, + mlir::LogicalResult GetDataFromAttribute(mlir::Operation &op, + mlir::Attribute &attr, DType dtype, std::vector<uint8_t> &u8_data) const; private: @@ -319,24 +301,39 @@ std::string TosaSerializationOperatorBuilder::GetVariableTensorName( } mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( - mlir::Attribute &attr, DType type, std::vector<uint8_t> &u8_data) const { + mlir::Operation &op, mlir::Attribute &attr, DType type, + std::vector<uint8_t> &u8_data) const { auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>(); - if (type == DType_FP32) { + + switch (type) { + case DType_FP32: + case DType_BF16: + case DType_FP16: + case DType_FP8E4M3: + case DType_FP8E5M2: { std::vector<float> data; auto val_attr = attr.dyn_cast<mlir::FloatAttr>(); if (dense_attr) { - for (auto val : dense_attr.getValues<float>()) { - data.push_back(val); + for (auto val : dense_attr.getValues<mlir::APFloat>()) { + data.push_back(val.convertToFloat()); } } else if (val_attr) { data.push_back((float)val_attr.getValueAsDouble()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } else if (type == DType_INT8) { + + if (type == DType_FP16) { + TosaSerializationHandler::ConvertF16toU8(data, u8_data); + } else { + // for all other floating types, store as F32 values + TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } + break; + } + case DType_INT8: { std::vector<int8_t> data; auto val_attr = attr.dyn_cast<mlir::IntegerAttr>(); @@ -347,11 +344,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI8toU8(data, u8_data); - } else if (type == DType_INT16) { + break; + } + case DType_INT16: { std::vector<int16_t> data; auto val_attr = attr.dyn_cast<mlir::IntegerAttr>(); @@ -362,11 +361,13 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI16toU8(data, u8_data); - } else if (type == DType_INT32) { + break; + } + case DType_INT32: { std::vector<int32_t> data; auto val_attr = attr.dyn_cast<mlir::IntegerAttr>(); @@ -377,28 +378,32 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI32toU8(data, u8_data); - } else if (type == DType_INT48) { + break; + } + case DType_INT48: { std::vector<int64_t> data; auto val_attr = attr.dyn_cast<mlir::IntegerAttr>(); if (dense_attr) { - for (auto val : dense_attr.getValues<int64_t>()) { + for (auto valueIt : dense_attr.getValues<mlir::APInt>()) { + uint64_t val = valueIt.getLimitedValue(); data.push_back(val); } } else if (val_attr) { data.push_back(val_attr.getInt()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertI48toU8(data, u8_data); - } else if (type == DType_BOOL) { + break; + } + case DType_BOOL: { std::vector<bool> data; - auto val_attr = attr.dyn_cast<mlir::BoolAttr>(); if (dense_attr) { @@ -408,15 +413,18 @@ mlir::LogicalResult TosaSerializationOperatorBuilder::GetDataFromAttribute( } else if (val_attr) { data.push_back(val_attr.getValue()); } else { - llvm::errs() << "Unknown const attribute\n"; + op.emitOpError("Unknown const attribute"); return mlir::failure(); } TosaSerializationHandler::ConvertBooltoU8(data, u8_data); - } else { - llvm::errs() << "Unknown element type of const attribute\n"; + break; + } + default: { + op.emitOpError("Unknown element type of const attribute"); return mlir::failure(); } + } return mlir::success(); } @@ -669,19 +677,6 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>( return nullptr; } -#if 0 - // Gracefully handle constants of "constant unit" type which have no value - // by creating a numpy value of 0. - auto unit_val = op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::UnitAttr>(); - - if (unit_val) - { - std::vector<float> data = { 0.0 }; - type = DType_FP32; - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } -#endif - // Update tensor.data array with Const value attribute mlir::Attribute value_attr = op.getAttr("value"); if (!value_attr) { @@ -689,139 +684,10 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>( return nullptr; } std::vector<uint8_t> u8_data; - + mlir::Attribute attr = op.getAttr(llvm::StringRef("value")); DType type = ts->GetDtype(); - if (type == DType_FP32) { - std::vector<float> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>(); - - if (dense_attr) { - for (auto val : dense_attr.getValues<mlir::APFloat>()) { - data.push_back(val.convertToFloat()); - } - } else if (val_attr) { - data.push_back((float)val_attr.getValueAsDouble()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertF32toU8(data, u8_data); - } else if (type == DType_FP16) { - std::vector<float> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>(); - - if (dense_attr) { - for (auto val : dense_attr.getValues<mlir::APFloat>()) { - data.push_back(val.convertToFloat()); - } - } else if (val_attr) { - data.push_back((float)val_attr.getValueAsDouble()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertF16toU8(data, u8_data); - } else if (type == DType_INT8) { - std::vector<int8_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); - if (dense_attr) { - for (auto val : dense_attr.getValues<int8_t>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI8toU8(data, u8_data); - } else if (type == DType_INT16) { - std::vector<int16_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); - - if (dense_attr) { - for (auto val : dense_attr.getValues<int16_t>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI16toU8(data, u8_data); - } else if (type == DType_INT32) { - std::vector<int32_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); - - if (dense_attr) { - for (auto val : dense_attr.getValues<int32_t>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI32toU8(data, u8_data); - } else if (type == DType_INT48) { - std::vector<int64_t> data; - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::IntegerAttr>(); - - if (dense_attr) { - for (auto valueIt : dense_attr.getValues<mlir::APInt>()) { - uint64_t val = valueIt.getLimitedValue(); - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getInt()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - TosaSerializationHandler::ConvertI48toU8(data, u8_data); - } else if (type == DType_BOOL) { - std::vector<bool> data; - - auto dense_attr = op.getAttr(llvm::StringRef("value")) - .dyn_cast<mlir::DenseElementsAttr>(); - auto val_attr = - op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::BoolAttr>(); - - if (dense_attr) { - for (auto val : dense_attr.getValues<bool>()) { - data.push_back(val); - } - } else if (val_attr) { - data.push_back(val_attr.getValue()); - } else { - op.emitOpError("Unknown const attribute"); - return nullptr; - } - - TosaSerializationHandler::ConvertBooltoU8(data, u8_data); - } else { - op.emitOpError("Unknown element type of const attribute"); + if (GetDataFromAttribute(op, attr, type, u8_data).failed()) { return nullptr; } @@ -1898,7 +1764,7 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( if (initial_value) { if (initial_value.isa<mlir::DenseElementsAttr>()) { if (op_builder - .GetDataFromAttribute(initial_value, element_type, u8_data) + .GetDataFromAttribute(*op, initial_value, element_type, u8_data) .failed()) { llvm::errs() << "ERROR: GetDataFromAttribute() fails when building " "initial_value of variable tensor\n"; @@ -1909,14 +1775,6 @@ mlir::LogicalResult TosaSerializationBlockBuilder::BuildAllOpsInBlock( return mlir::failure(); } } else { - uint64_t num_elements = 1; - for (int64_t dim : tensor_type.getShape()) { - num_elements *= dim; - } - uint64_t num_bits = num_elements * GetDTypeSize(element_type); - uint64_t num_bytes = - (num_bits % 8 == 0) ? (num_bits / 8) : (num_bits / 8) + 1; - // std::fill_n(u8_data.begin(), num_bytes, 0); TosaSerializationHandler::ForceAlignTensorData(u8_data); } ser_tensor->SetData(u8_data); |