diff options
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 18 | ||||
-rw-r--r-- | reference_model/src/ops/data_nodes.cc | 4 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 51 | ||||
-rw-r--r-- | reference_model/src/ops/scatter_gather.cc | 6 | ||||
-rw-r--r-- | reference_model/src/ops/template_types.h | 24 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 18 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 175 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.h | 278 |
8 files changed, 569 insertions, 5 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index ec9614a..ddf0713 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -759,6 +759,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); @@ -768,6 +770,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16); @@ -776,6 +780,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); DEF_INSTANTIATE_RESHAPE(OpReshape, BF16); @@ -785,6 +791,8 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); DEF_INSTANTIATE_RESHAPE(OpReshape, FP64); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E4M3); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); @@ -794,6 +802,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); @@ -803,6 +813,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); @@ -812,6 +824,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); @@ -821,6 +835,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); @@ -830,3 +846,5 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 705981c..64001a9 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -105,3 +105,5 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index af8332e..6d66c07 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -55,6 +55,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2); break; case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); @@ -64,6 +66,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E4M3, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16); break; case Op_CONV2D: DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -74,6 +78,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); break; case Op_CONV3D: DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); @@ -84,6 +91,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); break; case Op_DEPTHWISE_CONV2D: DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); @@ -94,6 +103,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); @@ -117,6 +128,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E4M3, FP16); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E5M2, FP16); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -125,6 +138,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E5M2); break; case Op_RFFT2D: DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); @@ -139,6 +154,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); break; // activation_funcs @@ -409,6 +426,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2); break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); @@ -419,6 +438,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2); break; case Op_DIM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16); @@ -428,6 +449,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2); break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); @@ -438,6 +461,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, INT32); DEF_FACTORY_RESHAPE(OpReshape, BOOL); DEF_FACTORY_RESHAPE(OpReshape, FP64); + DEF_FACTORY_RESHAPE(OpReshape, FP8E4M3); + DEF_FACTORY_RESHAPE(OpReshape, FP8E5M2); break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); @@ -448,6 +473,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2); break; case Op_SLICE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); @@ -458,6 +485,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2); break; case Op_TILE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); @@ -468,6 +497,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2); break; case Op_TRANSPOSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); @@ -478,6 +509,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2); break; // scatter_gather @@ -489,6 +522,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); DEF_FACTORY_ONE_TYPE(OpGather, FP64); + DEF_FACTORY_ONE_TYPE(OpGather, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpGather, FP8E5M2); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); @@ -498,6 +533,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); DEF_FACTORY_ONE_TYPE(OpScatter, FP64); + DEF_FACTORY_ONE_TYPE(OpScatter, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpScatter, FP8E5M2); break; // image @@ -524,6 +561,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2); break; // type_conversion @@ -569,6 +608,18 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index bd16ad1..85397ae 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -236,6 +236,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); @@ -244,3 +246,5 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E5M2); diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 342d5c2..41e6061 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -88,6 +88,18 @@ struct GetEigenType<TOSA_REF_TYPE_BF16> using type = float; }; template <> +struct GetEigenType<TOSA_REF_TYPE_FP8E4M3> +{ + // NOTE: full precision used + using type = float; +}; +template <> +struct GetEigenType<TOSA_REF_TYPE_FP8E5M2> +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType<TOSA_REF_TYPE_INT32> { using type = int32_t; @@ -200,6 +212,16 @@ struct GetNumBits<TOSA_REF_TYPE_FP16> { static constexpr int32_t value = 16; }; +template <> +struct GetNumBits<TOSA_REF_TYPE_FP8E4M3> +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits<TOSA_REF_TYPE_FP8E5M2> +{ + static constexpr int32_t value = 8; +}; // Meta function to get quantized min/max in compile time template <TOSA_REF_TYPE T> diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index dd66f79..124dc87 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -555,7 +555,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() } } if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 && - Dtype != TOSA_REF_TYPE_FP64) + Dtype != TOSA_REF_TYPE_FP64 && Dtype != TOSA_REF_TYPE_FP8E4M3 && Dtype != TOSA_REF_TYPE_FP8E5M2) { try { @@ -2155,6 +2155,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32); @@ -2163,6 +2165,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16); // [in_t, weight_t, out_t] DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -2173,6 +2177,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); @@ -2182,6 +2188,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); @@ -2191,6 +2199,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64); @@ -2211,6 +2221,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E4M3, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E5M2, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); @@ -2218,6 +2230,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64); @@ -2230,3 +2244,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 484f768..5dbc7bd 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -15,6 +15,7 @@ #include "type_conversion.h" #include "arith_util.h" +#include "float_utils.h" #include "half.hpp" #include "quant_util.h" #include "template_types.h" @@ -24,6 +25,12 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; +using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>; +using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>; +using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>; +using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>; +using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>; + template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_RESCALE, id_) @@ -527,6 +534,162 @@ CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper() } template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to integer + fcn = [](float in) -> OutEigenType { + if (in >= float(OutMax)) + return OutMax; + if (in <= float(OutMin)) + return OutMin; + + OutEigenType out = std::rint(in); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to fp16 (stored as fp32) + fcn = [](float in) -> float { + half_float::half h = half_float::half(in); + float out = half_float::half_cast<half_float::half, float>(h); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to bf16 (stored as fp32) + fcn = [](float in) -> float { return (float)in; }; +} + +CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to fp32 + fcn = [](InEigenType in) -> OutEigenType { return in; }; +} + +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to integer + fcn = [](float in) -> OutEigenType { + if (in >= float(OutMax)) + return OutMax; + if (in <= float(OutMin)) + return OutMin; + + OutEigenType out = std::rint(in); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to fp16 (stored as fp32) + fcn = [](float in) -> float { + half_float::half h = half_float::half(in); + float out = half_float::half_cast<half_float::half, float>(h); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to bf16 (stored as fp32) + fcn = [](float in) -> float { return (float)in; }; +} + +CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to fp32 + fcn = [](InEigenType in) -> OutEigenType { return in; }; +} + +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper() +{ + // Integer data converted to fp8e4m3 (stored as fp32) + fcn = [](InEigenType in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e4m3>(float(in))); + float out = static_cast<float>(f); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper() +{ + // fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e4m3>(in)); + float out = static_cast<float>(f); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper() +{ + // bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e4m3>(in)); + float out = static_cast<float>(f); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper() +{ + // fp32 data converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e4m3>(in)); + float out = static_cast<float>(f); + return out; + }; +} + +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper() +{ + // Integer data converted to fp8e5m2 (stored as fp32) + fcn = [](InEigenType in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e5m2>(float(in))); + float out = static_cast<float>(f); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper() +{ + // fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e5m2>(in)); + float out = static_cast<float>(f); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper() +{ + // bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e5m2>(in)); + float out = static_cast<float>(f); + return out; + }; +} + +CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper() +{ + // fp32 data converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast<fp32>(static_cast<fp8e5m2>(in)); + float out = static_cast<float>(f); + return out; + }; +} + +template <TOSA_REF_TYPE OutDtype> CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper() { switch (OutDtype) @@ -597,6 +760,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 98799a0..75f244d 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -277,6 +277,282 @@ private: }; template <TOSA_REF_TYPE OutDtype> +class CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + static constexpr int32_t OutMin = GetQMin<OutDtype>::value; + static constexpr int32_t OutMax = GetQMax<OutDtype>::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <TOSA_REF_TYPE OutDtype> +class CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + static constexpr int32_t OutMin = GetQMin<OutDtype>::value; + static constexpr int32_t OutMax = GetQMax<OutDtype>::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <TOSA_REF_TYPE InDtype> +class CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3> +{ +public: + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <TOSA_REF_TYPE InDtype> +class CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2> +{ +public: + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <TOSA_REF_TYPE OutDtype> class CastHelper<TOSA_REF_TYPE_FP64, OutDtype> { public: |