diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index a5e1a20..917d56e 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -142,6 +142,7 @@ int OpAdd<Rank, Dtype>::register_fcn() return static_cast<InEigenType>(res_in_64); }; break; + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; @@ -369,6 +370,7 @@ int OpMaximum<Rank, Dtype>::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; @@ -385,6 +387,7 @@ int OpMinimum<Rank, Dtype>::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; @@ -403,6 +406,7 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() switch (InDtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; break; @@ -452,6 +456,7 @@ int OpPow<Rank, Dtype>::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; break; @@ -476,6 +481,7 @@ int OpSub<Rank, Dtype>::register_fcn() return static_cast<InEigenType>(res_in_64); }; break; + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; @@ -574,6 +580,7 @@ int OpTable<Rank, InDtype>::eval() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); @@ -609,24 +616,32 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); +// Instantiation of nodes for comparison operators opEqual, opGreater +// and opGreaterEqual +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); |