diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/CLHelpers.cpp')
-rw-r--r-- | compute_kernel_writer/src/cl/CLHelpers.cpp | 43 |
1 files changed, 41 insertions, 2 deletions
diff --git a/compute_kernel_writer/src/cl/CLHelpers.cpp b/compute_kernel_writer/src/cl/CLHelpers.cpp index 08108e383f..f62e1c28e6 100644 --- a/compute_kernel_writer/src/cl/CLHelpers.cpp +++ b/compute_kernel_writer/src/cl/CLHelpers.cpp @@ -21,10 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ + #include "src/cl/CLHelpers.h" + #include "ckw/Error.h" #include "ckw/types/DataType.h" #include "ckw/types/TensorStorageType.h" +#include "src/types/DataTypeHelpers.h" namespace ckw { @@ -142,10 +145,46 @@ std::string cl_get_variable_storagetype_as_string(TensorStorageType storage) return res; } +std::tuple<bool, std::string> cl_get_unary_op(UnaryOp op) +{ + switch(op) + { + case UnaryOp::LogicalNot: + return { false, "!" }; + + case UnaryOp::BitwiseNot: + return { false, "~" }; + + case UnaryOp::Exp: + return { true, "exp" }; + + case UnaryOp::Tanh: + return { true, "tanh" }; + + case UnaryOp::Sqrt: + return { true, "sqrt" }; + + case UnaryOp::Erf: + return { true, "erf" }; + + case UnaryOp::Fabs: + return { true, "fabs" }; + + case UnaryOp::Log: + return { true, "log" }; + + case UnaryOp::Round: + return { true, "round" }; + + default: + CKW_THROW_MSG("Unsupported unary operation!"); + } +} + std::string cl_data_type_rounded_up_to_valid_vector_width(DataType dt, int32_t width) { - std::string data_type; - const int32_t w = cl_round_up_to_nearest_valid_vector_width(width); + std::string data_type; + const int32_t w = cl_round_up_to_nearest_valid_vector_width(width); data_type += cl_get_variable_datatype_as_string(dt, 1); if(w != 1) { |