aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-28 22:06:56 +0000
committerTai Ly <tai.ly@arm.com>2023-05-05 19:23:15 +0000
commita4d748b08accce06fab93e2d2b96e499b35ae89b (patch)
tree20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/op_factory.cc
parent0c71686875618b2e11290273b7a05b88ef8a8aae (diff)
downloadreference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc60
1 files changed, 57 insertions, 3 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 1db3974..0a78884 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -37,11 +37,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
Op opType,
TosaAttributeBase* attribute,
uint64_t id,
- DType inputDType,
+ TOSA_REF_TYPE inputDTYPE,
int inputRank,
- DType outputDType,
+ TOSA_REF_TYPE outputDTYPE,
int outputRank,
- DType weightDType,
+ TOSA_REF_TYPE weightDTYPE,
int weightRank)
{
switch (opType)
@@ -53,6 +53,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
@@ -61,6 +62,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64);
break;
case Op_CONV2D:
DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -70,6 +72,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
break;
case Op_CONV3D:
DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
@@ -79,6 +82,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
@@ -88,9 +92,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
+ DEF_FACTORY_ONE_TYPE(OpFFT2d, FP64);
break;
case Op_FULLY_CONNECTED:
DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
@@ -100,6 +106,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
break;
case Op_MATMUL:
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16);
@@ -108,6 +115,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -115,9 +123,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64);
break;
case Op_RFFT2D:
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
+ DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
break;
case Op_TRANSPOSE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
@@ -127,6 +137,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
break;
// activation_funcs
@@ -136,16 +147,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64);
break;
case Op_SIGMOID:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64);
break;
case Op_TANH:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
break;
// ewise_binary
@@ -154,6 +168,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64);
break;
case Op_ARITHMETIC_RIGHT_SHIFT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
@@ -202,12 +217,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64);
break;
case Op_MINIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64);
break;
case Op_MUL:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
@@ -216,17 +233,20 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64);
break;
case Op_POW:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64);
break;
case Op_SUB:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64);
break;
case Op_TABLE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
@@ -239,6 +259,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64);
break;
case Op_BITWISE_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
@@ -249,6 +270,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
break;
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
@@ -257,16 +279,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64);
break;
case Op_FLOOR:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64);
break;
case Op_LOG:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64);
break;
case Op_LOGICAL_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
@@ -278,16 +303,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64);
break;
case Op_RECIPROCAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64);
break;
case Op_RSQRT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
break;
// ewise_ternary
@@ -299,6 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64);
break;
// comparison
@@ -307,18 +336,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64);
break;
case Op_GREATER:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64);
break;
case Op_GREATER_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64);
break;
// reduction
@@ -335,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);
break;
case Op_REDUCE_MIN:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
@@ -343,16 +376,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);
break;
case Op_REDUCE_PRODUCT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);
break;
case Op_REDUCE_SUM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
@@ -365,6 +401,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64);
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
@@ -374,6 +411,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
@@ -383,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, INT16);
DEF_FACTORY_RESHAPE(OpReshape, INT32);
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
+ DEF_FACTORY_RESHAPE(OpReshape, FP64);
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
@@ -392,6 +431,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
break;
case Op_SLICE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
@@ -401,6 +441,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
break;
case Op_TILE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
@@ -410,6 +451,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
@@ -419,6 +461,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
break;
// scatter_gather
@@ -429,6 +472,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, FP16);
DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP64);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
@@ -437,6 +481,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP64);
break;
// image
@@ -448,6 +493,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32);
+ DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OpResize, FP64, FP64);
break;
// data_nodes
@@ -461,6 +507,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
break;
// type_conversion
@@ -499,6 +546,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);