aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc76
1 files changed, 31 insertions, 45 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 4a06248..b326c63 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -49,57 +49,53 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// tensor_ops
case Op_ARGMAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT);
- DEF_FACTORY_ONE_TYPE(OpAvgPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16);
break;
case Op_CONV2D:
DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8);
DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
break;
case Op_FULLY_CONNECTED:
DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT8);
DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8);
break;
case Op_MATMUL:
DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT);
- DEF_FACTORY_ONE_TYPE(OpMatMul, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMatMul, INT8);
DEF_FACTORY_ONE_TYPE(OpMatMul, INT16);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT);
- DEF_FACTORY_ONE_TYPE(OpMaxPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
break;
case Op_TRANSPOSE_CONV2D:
DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT8);
DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
break;
// activation_funcs
case Op_CLAMP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
break;
case Op_RELUN:
@@ -124,17 +120,17 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
break;
case Op_BITWISE_AND:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
break;
case Op_BITWISE_OR:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
break;
case Op_BITWISE_XOR:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
break;
@@ -188,7 +184,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
break;
case Op_BITWISE_NOT:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
break;
@@ -212,7 +208,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
break;
case Op_NEGATE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
break;
@@ -226,7 +222,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// ewise_ternary
case Op_SELECT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
@@ -256,13 +251,13 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
break;
case Op_REDUCE_MAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
break;
case Op_REDUCE_MIN:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
break;
@@ -277,7 +272,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// data layout
case Op_CONCAT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
@@ -286,14 +280,12 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FLOAT);
- DEF_FACTORY_RESHAPE(OpReshape, AINT8);
DEF_FACTORY_RESHAPE(OpReshape, INT8);
DEF_FACTORY_RESHAPE(OpReshape, INT16);
DEF_FACTORY_RESHAPE(OpReshape, INT32);
@@ -301,7 +293,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
@@ -309,7 +300,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
break;
case Op_SLICE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
@@ -317,7 +307,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
break;
case Op_TILE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
@@ -326,7 +315,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
case Op_TRANSPOSE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
@@ -334,13 +322,13 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// scatter_gather
case Op_GATHER:
- DEF_FACTORY_ONE_TYPE(OpGather, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpGather, INT8);
DEF_FACTORY_ONE_TYPE(OpGather, INT16);
DEF_FACTORY_ONE_TYPE(OpGather, INT32);
DEF_FACTORY_ONE_TYPE(OpGather, FLOAT);
break;
case Op_SCATTER:
- DEF_FACTORY_ONE_TYPE(OpScatter, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT);
@@ -363,7 +351,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
case Op_IDENTITY:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
@@ -371,7 +358,6 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
case Op_IDENTITYN:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
@@ -399,20 +385,20 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
break;
case Op_RESCALE:
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
break;
// custom