aboutsummaryrefslogtreecommitdiff
path: root/src/TosaSerialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaSerialize.cpp')
-rw-r--r--src/TosaSerialize.cpp25
1 files changed, 23 insertions, 2 deletions
diff --git a/src/TosaSerialize.cpp b/src/TosaSerialize.cpp
index 263f51c..2d038f0 100644
--- a/src/TosaSerialize.cpp
+++ b/src/TosaSerialize.cpp
@@ -75,9 +75,12 @@ static ResizeMode ResizeModeStr2Enum(const std::string &mode_str) {
}
static DType Type2DType(mlir::Type element_type) {
- if (element_type.isF64() || element_type.isF32() || element_type.isF16() ||
- element_type.isBF16()) {
+ if (element_type.isF64() || element_type.isF32()) {
return DType_FP32;
+ } else if (element_type.isF16()) {
+ return DType_FP16;
+ } else if (element_type.isBF16()) {
+ return DType_BF16;
} else if (element_type.isUnsignedInteger(8)) {
return DType_UINT8;
} else if (element_type.isInteger(4)) {
@@ -658,6 +661,24 @@ TosaSerializationOperatorBuilder::build<mlir::tosa::ConstOp>(
return nullptr;
}
TosaSerializationHandler::ConvertF32toU8(data, u8_data);
+ } else if (type == DType_FP16) {
+ std::vector<float> data;
+ auto dense_attr = op.getAttr(llvm::StringRef("value"))
+ .dyn_cast<mlir::DenseElementsAttr>();
+ auto val_attr =
+ op.getAttr(llvm::StringRef("value")).dyn_cast<mlir::FloatAttr>();
+
+ if (dense_attr) {
+ for (auto val : dense_attr.getValues<mlir::APFloat>()) {
+ data.push_back(val.convertToFloat());
+ }
+ } else if (val_attr) {
+ data.push_back((float)val_attr.getValueAsDouble());
+ } else {
+ op.emitOpError("Unknown const attribute");
+ return nullptr;
+ }
+ TosaSerializationHandler::ConvertF16toU8(data, u8_data);
} else if (type == DType_INT8) {
std::vector<int8_t> data;
auto dense_attr = op.getAttr(llvm::StringRef("value"))