aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc51
1 files changed, 51 insertions, 0 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index af8332e..6d66c07 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -55,6 +55,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
@@ -64,6 +66,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E4M3, FP16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
break;
case Op_CONV2D:
DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -74,6 +78,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_CONV3D:
DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
@@ -84,6 +91,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
@@ -94,6 +103,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -117,6 +128,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E4M3, FP16);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E5M2, FP16);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -125,6 +138,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E5M2);
break;
case Op_RFFT2D:
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
@@ -139,6 +154,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
break;
// activation_funcs
@@ -409,6 +426,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2);
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
@@ -419,6 +438,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2);
break;
case Op_DIM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
@@ -428,6 +449,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2);
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
@@ -438,6 +461,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, INT32);
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
DEF_FACTORY_RESHAPE(OpReshape, FP64);
+ DEF_FACTORY_RESHAPE(OpReshape, FP8E4M3);
+ DEF_FACTORY_RESHAPE(OpReshape, FP8E5M2);
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
@@ -448,6 +473,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2);
break;
case Op_SLICE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
@@ -458,6 +485,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2);
break;
case Op_TILE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
@@ -468,6 +497,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2);
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
@@ -478,6 +509,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2);
break;
// scatter_gather
@@ -489,6 +522,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
DEF_FACTORY_ONE_TYPE(OpGather, FP64);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP8E5M2);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
@@ -498,6 +533,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
DEF_FACTORY_ONE_TYPE(OpScatter, FP64);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP8E5M2);
break;
// image
@@ -524,6 +561,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2);
break;
// type_conversion
@@ -569,6 +608,18 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);