aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp13
1 files changed, 11 insertions, 2 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 3c7db8e..87c363f 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -138,6 +138,12 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
case DType_BF16:
element_type = op_builder->getBF16Type();
break;
+ case DType_FP8E4M3:
+ element_type = op_builder->getFloat8E4M3FNType();
+ break;
+ case DType_FP8E5M2:
+ element_type = op_builder->getFloat8E5M2Type();
+ break;
case DType_SHAPE:
llvm::errs()
<< "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n";
@@ -172,7 +178,11 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type,
}
mlir::DenseElementsAttr value_attr;
switch (ts->GetDtype()) {
- case DType_FP32: {
+ case DType_FP32:
+ case DType_BF16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2: {
+ // for FP32, FP16 and FP8 types, value attributes are stored as FP32 values
std::vector<float> float_data;
TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data);
value_attr =
@@ -236,7 +246,6 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type,
}
case DType_UINT8:
case DType_UINT16:
- case DType_BF16:
default: {
llvm::errs() << "ERROR: " << op_name
<< " contains unsupported element type\n";