aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-12-01 22:58:55 +0000
committerTai Ly <tai.ly@arm.com>2023-12-02 03:56:41 +0000
commit20f6941b21f84cd5f0152d42f343b0992dd5a6e5 (patch)
tree6b9e191c32b60077206fcbf5c73c889eba681729 /src/TosaSerialize.cpp
parentfc32f56a067c526238c15de097fe78fdcab95cb5 (diff)
downloadtosa_mlir_translator-20f6941b21f84cd5f0152d42f343b0992dd5a6e5.tar.gz
[tosa_mlir_translator] Add FP16 support
serialize/deserialize FP16 tensors and constants Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Iab75aeda45983f328796f9463a57c69e86ab8f3e
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp25
1 files changed, 23 insertions, 2 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 263f51c..2d038f0 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -75,9 +75,12 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
}
static DType Type2DType(mlir::Type element_type) {
- if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
- element_type.isBF16()) {
+ if (element_type.isF64() || element_type.isF32()) {
return DType_FP32;
+ } else if (element_type.isF16()) {
+ return DType_FP16;
+ } else if (element_type.isBF16()) {
+ return DType_BF16;
} else if (element_type.isUnsignedInteger(8)) {
return DType_UINT8;
} else if (element_type.isInteger(4)) {
@@ -658,6 +661,24 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
TosaSerializationHandler::ConvertF32toU8(data, u8_data);
+ } else if (type == DType_FP16) {
+ std::vector<float> data;
+ auto dense_attr = op.getAttr(llvm::StringRef("value"))
+ .dyn_cast<mlir::DenseElementsAttr>();
+ auto val_attr =
+ op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
+
+ if (dense_attr) {
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ data.push_back(val.convertToFloat());
+ }
+ } else if (val_attr) {
+ data.push_back((float)val_attr.getValueAsDouble());
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
+ }
+ TosaSerializationHandler::ConvertF16toU8(data, u8_data);
} else if (type == DType_INT8) {
std::vector<int8_t> data;
auto dense_attr = op.getAttr(llvm::StringRef("value"))