diff options
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r-- | src/TosaDeserialize.cpp | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp index 3c7db8e..87c363f 100644 --- a/src/TosaDeserialize.cpp +++ b/src/TosaDeserialize.cpp @@ -138,6 +138,12 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder, case DType_BF16: element_type = op_builder->getBF16Type(); break; + case DType_FP8E4M3: + element_type = op_builder->getFloat8E4M3FNType(); + break; + case DType_FP8E5M2: + element_type = op_builder->getFloat8E5M2Type(); + break; case DType_SHAPE: llvm::errs() << "ERROR: Cannot construct RankedTensorType out of tosa.shape type \n"; @@ -172,7 +178,11 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type, } mlir::DenseElementsAttr value_attr; switch (ts->GetDtype()) { - case DType_FP32: { + case DType_FP32: + case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: { + // for FP32, FP16 and FP8 types, value attributes are stored as FP32 values std::vector<float> float_data; TosaSerializationHandler::ConvertU8toF32(data, out_size, float_data); value_attr = @@ -236,7 +246,6 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type, } case DType_UINT8: case DType_UINT16: - case DType_BF16: default: { llvm::errs() << "ERROR: " << op_name << " contains unsupported element type\n"; |