diff options
author | Tai Ly <tai.ly@arm.com> | 2023-03-28 22:06:56 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-05-05 19:23:15 +0000 |
commit | a4d748b08accce06fab93e2d2b96e499b35ae89b (patch) | |
tree | 20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/op_factory.cc | |
parent | 0c71686875618b2e11290273b7a05b88ef8a8aae (diff) | |
download | reference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz |
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model,
which will cause reference model to convert all floating point tensors
to FP64 tensors and compute all operators accordingly.
Also adds optional -p arguments to test runners tosa_verif_run_tests.py
and tosa_verif_framework_compiler_runner.py to run tests in precise mode
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 60 |
1 files changed, 57 insertions, 3 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 1db3974..0a78884 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -37,11 +37,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, Op opType, TosaAttributeBase* attribute, uint64_t id, - DType inputDType, + TOSA_REF_TYPE inputDTYPE, int inputRank, - DType outputDType, + TOSA_REF_TYPE outputDTYPE, int outputRank, - DType weightDType, + TOSA_REF_TYPE weightDTYPE, int weightRank) { switch (opType) @@ -53,6 +53,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); break; case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); @@ -61,6 +62,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64); break; case Op_CONV2D: DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -70,6 +72,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); break; case Op_CONV3D: DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); @@ -79,6 +82,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); break; case Op_DEPTHWISE_CONV2D: DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); @@ -88,9 +92,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); + DEF_FACTORY_ONE_TYPE(OpFFT2d, FP64); break; case Op_FULLY_CONNECTED: DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); @@ -100,6 +106,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64); break; case Op_MATMUL: DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16); @@ -108,6 +115,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -115,9 +123,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64); break; case Op_RFFT2D: DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); + DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64); break; case Op_TRANSPOSE_CONV2D: DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); @@ -127,6 +137,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); break; // activation_funcs @@ -136,16 +147,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64); break; case Op_SIGMOID: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64); break; case Op_TANH: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64); break; // ewise_binary @@ -154,6 +168,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64); break; case Op_ARITHMETIC_RIGHT_SHIFT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); @@ -202,12 +217,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64); break; case Op_MINIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64); break; case Op_MUL: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); @@ -216,17 +233,20 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64); break; case Op_POW: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64); break; case Op_SUB: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64); break; case Op_TABLE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); @@ -239,6 +259,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64); break; case Op_BITWISE_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); @@ -249,6 +270,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); break; case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); @@ -257,16 +279,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64); break; case Op_FLOOR: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64); break; case Op_LOG: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64); break; case Op_LOGICAL_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); @@ -278,16 +303,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64); break; case Op_RECIPROCAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64); break; case Op_RSQRT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); break; // ewise_ternary @@ -299,6 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64); break; // comparison @@ -307,18 +336,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64); break; case Op_GREATER: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64); break; case Op_GREATER_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64); break; // reduction @@ -335,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64); break; case Op_REDUCE_MIN: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); @@ -343,16 +376,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64); break; case Op_REDUCE_PRODUCT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64); break; case Op_REDUCE_SUM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; @@ -365,6 +401,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64); break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); @@ -374,6 +411,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); @@ -383,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, INT16); DEF_FACTORY_RESHAPE(OpReshape, INT32); DEF_FACTORY_RESHAPE(OpReshape, BOOL); + DEF_FACTORY_RESHAPE(OpReshape, FP64); break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); @@ -392,6 +431,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); break; case Op_SLICE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); @@ -401,6 +441,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); break; case Op_TILE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); @@ -410,6 +451,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); break; case Op_TRANSPOSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); @@ -419,6 +461,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); break; // scatter_gather @@ -429,6 +472,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, FP16); DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); + DEF_FACTORY_ONE_TYPE(OpGather, FP64); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); @@ -437,6 +481,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, FP16); DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); + DEF_FACTORY_ONE_TYPE(OpScatter, FP64); break; // image @@ -448,6 +493,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16); DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32); + DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OpResize, FP64, FP64); break; // data_nodes @@ -461,6 +507,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); break; // type_conversion @@ -499,6 +546,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); |