aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--reference_model/src/ops/data_layout.cc18
-rw-r--r--reference_model/src/ops/data_nodes.cc4
-rw-r--r--reference_model/src/ops/op_factory.cc51
-rw-r--r--reference_model/src/ops/scatter_gather.cc6
-rw-r--r--reference_model/src/ops/template_types.h24
-rw-r--r--reference_model/src/ops/tensor_ops.cc18
-rw-r--r--reference_model/src/ops/type_conversion.cc175
-rw-r--r--reference_model/src/ops/type_conversion.h278
8 files changed, 569 insertions, 5 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index ec9614a..ddf0713 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -759,6 +759,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
@@ -768,6 +770,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16);
@@ -776,6 +780,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
@@ -785,6 +791,8 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E4M3);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
@@ -794,6 +802,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
@@ -803,6 +813,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
@@ -812,6 +824,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
@@ -821,6 +835,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
@@ -830,3 +846,5 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2);
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index 705981c..64001a9 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -105,3 +105,5 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2);
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index af8332e..6d66c07 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -55,6 +55,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
@@ -64,6 +66,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E4M3, FP16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
break;
case Op_CONV2D:
DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -74,6 +78,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_CONV3D:
DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
@@ -84,6 +91,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
@@ -94,6 +103,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -117,6 +128,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E4M3, FP16);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E5M2, FP16);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -125,6 +138,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E5M2);
break;
case Op_RFFT2D:
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
@@ -139,6 +154,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
break;
// activation_funcs
@@ -409,6 +426,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2);
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
@@ -419,6 +438,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2);
break;
case Op_DIM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
@@ -428,6 +449,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2);
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
@@ -438,6 +461,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, INT32);
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
DEF_FACTORY_RESHAPE(OpReshape, FP64);
+ DEF_FACTORY_RESHAPE(OpReshape, FP8E4M3);
+ DEF_FACTORY_RESHAPE(OpReshape, FP8E5M2);
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
@@ -448,6 +473,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2);
break;
case Op_SLICE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
@@ -458,6 +485,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2);
break;
case Op_TILE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
@@ -468,6 +497,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2);
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
@@ -478,6 +509,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2);
break;
// scatter_gather
@@ -489,6 +522,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
DEF_FACTORY_ONE_TYPE(OpGather, FP64);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP8E5M2);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
@@ -498,6 +533,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
DEF_FACTORY_ONE_TYPE(OpScatter, FP64);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP8E5M2);
break;
// image
@@ -524,6 +561,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2);
break;
// type_conversion
@@ -569,6 +608,18 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index bd16ad1..85397ae 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -236,6 +236,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E4M3);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E5M2);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
@@ -244,3 +246,5 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E4M3);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E5M2);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 342d5c2..41e6061 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -88,6 +88,18 @@ struct GetEigenType<TOSA_REF_TYPE_BF16>
using type = float;
};
template <>
+struct GetEigenType<TOSA_REF_TYPE_FP8E4M3>
+{
+ // NOTE: full precision used
+ using type = float;
+};
+template <>
+struct GetEigenType<TOSA_REF_TYPE_FP8E5M2>
+{
+ // NOTE: full precision used
+ using type = float;
+};
+template <>
struct GetEigenType<TOSA_REF_TYPE_INT32>
{
using type = int32_t;
@@ -200,6 +212,16 @@ struct GetNumBits<TOSA_REF_TYPE_FP16>
{
static constexpr int32_t value = 16;
};
+template <>
+struct GetNumBits<TOSA_REF_TYPE_FP8E4M3>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<TOSA_REF_TYPE_FP8E5M2>
+{
+ static constexpr int32_t value = 8;
+};
// Meta function to get quantized min/max in compile time
template <TOSA_REF_TYPE T>
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index dd66f79..124dc87 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -555,7 +555,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
}
}
if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
- Dtype != TOSA_REF_TYPE_FP64)
+ Dtype != TOSA_REF_TYPE_FP64 && Dtype != TOSA_REF_TYPE_FP8E4M3 && Dtype != TOSA_REF_TYPE_FP8E5M2)
{
try
{
@@ -2155,6 +2155,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
@@ -2163,6 +2165,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16);
// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -2173,6 +2177,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
@@ -2182,6 +2188,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
@@ -2191,6 +2199,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
@@ -2211,6 +2221,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E4M3, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E5M2, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
@@ -2218,6 +2230,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E4M3);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
@@ -2230,3 +2244,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 484f768..5dbc7bd 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -15,6 +15,7 @@
#include "type_conversion.h"
#include "arith_util.h"
+#include "float_utils.h"
#include "half.hpp"
#include "quant_util.h"
#include "template_types.h"
@@ -24,6 +25,12 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
+using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>;
+using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>;
+using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;
+using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;
+using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;
+
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_RESCALE, id_)
@@ -527,6 +534,162 @@ CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
}
template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to integer
+ fcn = [](float in) -> OutEigenType {
+ if (in >= float(OutMax))
+ return OutMax;
+ if (in <= float(OutMin))
+ return OutMin;
+
+ OutEigenType out = std::rint(in);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to fp16 (stored as fp32)
+ fcn = [](float in) -> float {
+ half_float::half h = half_float::half(in);
+ float out = half_float::half_cast<half_float::half, float>(h);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to bf16 (stored as fp32)
+ fcn = [](float in) -> float { return (float)in; };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to fp32
+ fcn = [](InEigenType in) -> OutEigenType { return in; };
+}
+
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to integer
+ fcn = [](float in) -> OutEigenType {
+ if (in >= float(OutMax))
+ return OutMax;
+ if (in <= float(OutMin))
+ return OutMin;
+
+ OutEigenType out = std::rint(in);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to fp16 (stored as fp32)
+ fcn = [](float in) -> float {
+ half_float::half h = half_float::half(in);
+ float out = half_float::half_cast<half_float::half, float>(h);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to bf16 (stored as fp32)
+ fcn = [](float in) -> float { return (float)in; };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to fp32
+ fcn = [](InEigenType in) -> OutEigenType { return in; };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // Integer data converted to fp8e4m3 (stored as fp32)
+ fcn = [](InEigenType in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(float(in)));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // fp32 data converted to fp8e4m3 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // Integer data converted to fp8e5m2 (stored as fp32)
+ fcn = [](InEigenType in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(float(in)));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // fp32 data converted to fp8e5m2 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+template <TOSA_REF_TYPE OutDtype>
CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
{
switch (OutDtype)
@@ -597,6 +760,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 98799a0..75f244d 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -277,6 +277,282 @@ private:
};
template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE OutDtype>
class CastHelper<TOSA_REF_TYPE_FP64, OutDtype>
{
public: