aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc20
1 files changed, 12 insertions, 8 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index bad0c40..4a06248 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -334,14 +334,17 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// scatter_gather
case Op_GATHER:
- {
- // output.rank = input.rank - 1 + index.rank
- int32_t index_rank = outputRank - inputRank + 1;
- DEF_FACTORY_GATHER(OpGather, AINT8);
- DEF_FACTORY_GATHER(OpGather, INT16);
- DEF_FACTORY_GATHER(OpGather, INT32);
- }
- break;
+ DEF_FACTORY_ONE_TYPE(OpGather, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpGather, INT16);
+ DEF_FACTORY_ONE_TYPE(OpGather, INT32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FLOAT);
+ break;
+ case Op_SCATTER:
+ DEF_FACTORY_ONE_TYPE(OpScatter, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
+ DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT);
+ break;
// image
case Op_RESIZE:
@@ -349,6 +352,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8);
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, FLOAT, FLOAT);
break;
// data_nodes