diff options
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index bad0c40..4a06248 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -334,14 +334,17 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, // scatter_gather case Op_GATHER: - { - // output.rank = input.rank - 1 + index.rank - int32_t index_rank = outputRank - inputRank + 1; - DEF_FACTORY_GATHER(OpGather, AINT8); - DEF_FACTORY_GATHER(OpGather, INT16); - DEF_FACTORY_GATHER(OpGather, INT32); - } - break; + DEF_FACTORY_ONE_TYPE(OpGather, AINT8); + DEF_FACTORY_ONE_TYPE(OpGather, INT16); + DEF_FACTORY_ONE_TYPE(OpGather, INT32); + DEF_FACTORY_ONE_TYPE(OpGather, FLOAT); + break; + case Op_SCATTER: + DEF_FACTORY_ONE_TYPE(OpScatter, AINT8); + DEF_FACTORY_ONE_TYPE(OpScatter, INT16); + DEF_FACTORY_ONE_TYPE(OpScatter, INT32); + DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT); + break; // image case Op_RESIZE: @@ -349,6 +352,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8); DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16); + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, FLOAT, FLOAT); break; // data_nodes |