aboutsummaryrefslogtreecommitdiff
path: root/src/TosaDeserialize.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-12-01 22:58:55 +0000
committerTai Ly <tai.ly@arm.com>2023-12-02 03:56:41 +0000
commit20f6941b21f84cd5f0152d42f343b0992dd5a6e5 (patch)
tree6b9e191c32b60077206fcbf5c73c889eba681729 /src/TosaDeserialize.cpp
parentfc32f56a067c526238c15de097fe78fdcab95cb5 (diff)
downloadtosa_mlir_translator-20f6941b21f84cd5f0152d42f343b0992dd5a6e5.tar.gz
[tosa_mlir_translator] Add FP16 support
serialize/deserialize FP16 tensors and constants Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Iab75aeda45983f328796f9463a57c69e86ab8f3e
Diffstat (limited to 'src/TosaDeserialize.cpp')
-rw-r--r--src/TosaDeserialize.cpp17
1 files changed, 17 insertions, 0 deletions
diff --git a/src/TosaDeserialize.cpp b/src/TosaDeserialize.cpp
index 8799028..f1b7d98 100644
--- a/src/TosaDeserialize.cpp
+++ b/src/TosaDeserialize.cpp
@@ -132,6 +132,12 @@ mlir::LogicalResult BuildTensorType(mlir::OpBuilder *op_builder,
case DType_UINT16:
element_type = op_builder->getIntegerType(16, false);
break;
+ case DType_FP16:
+ element_type = op_builder->getF16Type();
+ break;
+ case DType_BF16:
+ element_type = op_builder->getBF16Type();
+ break;
case DType_SHAPE:
element_type = op_builder->getIntegerType(64);
break;
@@ -220,6 +226,17 @@ ConstructConstAttr(const mlir::RankedTensorType &output_type,
value_attr = mlir::DenseElementsAttr::get(output_type, bool_values);
break;
}
+ case DType_FP16: {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF16(data, out_size, float_data);
+ value_attr =
+ mlir::DenseElementsAttr::get(output_type, llvm::ArrayRef(float_data));
+ break;
+ }
+ case DType_UINT8:
+ case DType_UINT16:
+ case DType_BF16:
+ case DType_SHAPE:
default: {
llvm::errs() << "ERROR: " << op_name
<< " contains unsupported element type\n";