diff options
author | Tai Ly <tai.ly@arm.com> | 2023-12-01 22:58:55 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-12-02 03:56:41 +0000 |
commit | 20f6941b21f84cd5f0152d42f343b0992dd5a6e5 (patch) | |
tree | 6b9e191c32b60077206fcbf5c73c889eba681729 /src/TosaSerialize.cpp | |
parent | fc32f56a067c526238c15de097fe78fdcab95cb5 (diff) | |
download | tosa_mlir_translator-20f6941b21f84cd5f0152d42f343b0992dd5a6e5.tar.gz |
[tosa_mlir_translator] Add FP16 support
serialize/deserialize FP16 tensors and constants
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Iab75aeda45983f328796f9463a57c69e86ab8f3e
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r-- | src/TosaSerialize.cpp | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp index 263f51c..2d038f0 100644 --- a/src/TosaSerialize.cpp +++ b/src/TosaSerialize.cpp @@ -75,9 +75,12 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) { } static DType Type2DType(mlir::Type element_type) { - if (element_type.isF64() || element_type.isF32() || element_type.isF16() || - element_type.isBF16()) { + if (element_type.isF64() || element_type.isF32()) { return DType_FP32; + } else if (element_type.isF16()) { + return DType_FP16; + } else if (element_type.isBF16()) { + return DType_BF16; } else if (element_type.isUnsignedInteger(8)) { return DType_UINT8; } else if (element_type.isInteger(4)) { @@ -658,6 +661,24 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>( return nullptr; } TosaSerializationHandler::ConvertF32toU8(data, u8_data); + } else if (type == DType_FP16) { + std::vector<float> data; + auto dense_attr = op.getAttr(llvm::StringRef("value")) + .dyn_cast<mlir::DenseElementsAttr>(); + auto val_attr = + op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>(); + + if (dense_attr) { + for (auto val : dense_attr.getValues<mlir::APFloat>()) { + data.push_back(val.convertToFloat()); + } + } else if (val_attr) { + data.push_back((float)val_attr.getValueAsDouble()); + } else { + op.emitOpError("Unknown const attribute"); + return nullptr; + } + TosaSerializationHandler::ConvertF16toU8(data, u8_data); } else if (type == DType_INT8) { std::vector<int8_t> data; auto dense_attr = op.getAttr(llvm::StringRef("value")) |