From 20f6941b21f84cd5f0152d42f343b0992dd5a6e5 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 1 Dec 2023 22:58:55 +0000 Subject: [tosa_mlir_translator] Add FP16 support serialize/deserialize FP16 tensors and constants Signed-off-by: Tai Ly Change-Id: Iab75aeda45983f328796f9463a57c69e86ab8f3e --- src/TosaDeserialize.cpp | 17 +++++++++++++++++ src/TosaSerialize.cpp | 25 +++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 8799028..f1b7d98 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -132,6 +132,12 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, case DType_UINT16: element_type = op_builder->getIntegerType(16, false); break; + case DType_FP16: + element_type = op_builder->getF16Type(); + break; + case DType_BF16: + element_type = op_builder->getBF16Type(); + break; case DType_SHAPE: element_type = op_builder->getIntegerType(64); break; @@ -220,6 +226,17 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type, value_attr = mlir::DenseElementsAttr::get(output_type, bool_values); break; } + case DType_FP16: { + std::vector float_data; + TosaSerializationHandler::ConvertU8toF16(data, out_size, float_data); + value_attr = + mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data)); + break; + } + case DType_UINT8: + case DType_UINT16: + case DType_BF16: + case DType_SHAPE: default: { llvm::errs() << "ERROR: " << op_name << " contains unsupported element type\n"; diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 263f51c..2d038f0 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -75,9 +75,12 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { } static DType Type2DType(mlir::Type element_type) { - if (element_type.isF64() || element_type.isF32() || element_type.isF16() || - element_type.isBF16()) { + if (element_type.isF64() || element_type.isF32()) { return DType_FP32; + } else if (element_type.isF16()) { + return DType_FP16; + } else if (element_type.isBF16()) { + return DType_BF16; } else if (element_type.isUnsignedInteger(8)) { return DType_UINT8; } else if (element_type.isInteger(4)) { @@ -658,6 +661,24 @@ TosaSerializationOperatorBuilder::build( return nullptr; } TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } else if (type == DType_FP16) { + std::vector data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast(); + + if (dense_attr) { + for (auto val : dense_attr.getValues()) { + data.push_back(val.convertToFloat()); + } + } else if (val_attr) { + data.push_back((float)val_attr.getValueAsDouble()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertF16toU8(data, u8_data); } else if (type == DType_INT8) { std::vector data; auto dense_attr = op.getAttr(llvm::StringRef("value")) -- cgit v1.2.1