aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl/CLHelpers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl/CLHelpers.cpp')
-rw-r--r--compute_kernel_writer/src/cl/CLHelpers.cpp77
1 files changed, 39 insertions, 38 deletions
diff --git a/compute_kernel_writer/src/cl/CLHelpers.cpp b/compute_kernel_writer/src/cl/CLHelpers.cpp
index ff4408b1a3..8e4a932764 100644
--- a/compute_kernel_writer/src/cl/CLHelpers.cpp
+++ b/compute_kernel_writer/src/cl/CLHelpers.cpp
@@ -28,6 +28,7 @@
#include "ckw/types/DataType.h"
#include "ckw/types/Operators.h"
#include "ckw/types/TensorStorageType.h"
+
#include "src/types/DataTypeHelpers.h"
namespace ckw
@@ -35,7 +36,7 @@ namespace ckw
bool cl_validate_vector_length(int32_t len)
{
bool valid_vector_length = true;
- if(len < 1 || len > 16 || (len > 4 && len < 8) || (len > 8 && len < 16))
+ if (len < 1 || len > 16 || (len > 4 && len < 8) || (len > 8 && len < 16))
{
valid_vector_length = false;
}
@@ -44,14 +45,14 @@ bool cl_validate_vector_length(int32_t len)
std::string cl_get_variable_datatype_as_string(DataType dt, int32_t len)
{
- if(cl_validate_vector_length(len) == false)
+ if (cl_validate_vector_length(len) == false)
{
CKW_THROW_MSG("Unsupported vector length");
return "";
}
std::string res;
- switch(dt)
+ switch (dt)
{
case DataType::Fp32:
res += "float";
@@ -85,7 +86,7 @@ std::string cl_get_variable_datatype_as_string(DataType dt, int32_t len)
return "";
}
- if(len > 1)
+ if (len > 1)
{
res += std::to_string(len);
}
@@ -95,7 +96,7 @@ std::string cl_get_variable_datatype_as_string(DataType dt, int32_t len)
int32_t cl_round_up_to_nearest_valid_vector_width(int32_t width)
{
- switch(width)
+ switch (width)
{
case 1:
return 1;
@@ -128,7 +129,7 @@ int32_t cl_round_up_to_nearest_valid_vector_width(int32_t width)
std::string cl_get_variable_storagetype_as_string(TensorStorageType storage)
{
std::string res;
- switch(storage)
+ switch (storage)
{
case TensorStorageType::BufferUint8Ptr:
res += "__global uchar*";
@@ -148,7 +149,7 @@ std::string cl_get_variable_storagetype_as_string(TensorStorageType storage)
std::string cl_get_assignment_op_as_string(AssignmentOp op)
{
- switch(op)
+ switch (op)
{
case AssignmentOp::Increment:
return "+=";
@@ -163,34 +164,34 @@ std::string cl_get_assignment_op_as_string(AssignmentOp op)
std::tuple<bool, std::string> cl_get_unary_op(UnaryOp op)
{
- switch(op)
+ switch (op)
{
case UnaryOp::LogicalNot:
- return { false, "!" };
+ return {false, "!"};
case UnaryOp::BitwiseNot:
- return { false, "~" };
+ return {false, "~"};
case UnaryOp::Exp:
- return { true, "exp" };
+ return {true, "exp"};
case UnaryOp::Tanh:
- return { true, "tanh" };
+ return {true, "tanh"};
case UnaryOp::Sqrt:
- return { true, "sqrt" };
+ return {true, "sqrt"};
case UnaryOp::Erf:
- return { true, "erf" };
+ return {true, "erf"};
case UnaryOp::Fabs:
- return { true, "fabs" };
+ return {true, "fabs"};
case UnaryOp::Log:
- return { true, "log" };
+ return {true, "log"};
case UnaryOp::Round:
- return { true, "round" };
+ return {true, "round"};
default:
CKW_THROW_MSG("Unsupported unary operation!");
@@ -201,52 +202,52 @@ std::tuple<bool, std::string> cl_get_binary_op(BinaryOp op, DataType data_type)
{
const auto is_float = is_data_type_float(data_type);
- switch(op)
+ switch (op)
{
case BinaryOp::Add:
- return { false, "+" };
+ return {false, "+"};
case BinaryOp::Sub:
- return { false, "-" };
+ return {false, "-"};
case BinaryOp::Mul:
- return { false, "*" };
+ return {false, "*"};
case BinaryOp::Div:
- return { false, "/" };
+ return {false, "/"};
case BinaryOp::Mod:
- return { false, "%" };
+ return {false, "%"};
case BinaryOp::Equal:
- return { false, "==" };
+ return {false, "=="};
case BinaryOp::Less:
- return { false, "<" };
+ return {false, "<"};
case BinaryOp::LessEqual:
- return { false, "<=" };
+ return {false, "<="};
case BinaryOp::Greater:
- return { false, ">" };
+ return {false, ">"};
case BinaryOp::GreaterEqual:
- return { false, ">=" };
+ return {false, ">="};
case BinaryOp::LogicalAnd:
- return { false, "&&" };
+ return {false, "&&"};
case BinaryOp::LogicalOr:
- return { false, "||" };
+ return {false, "||"};
case BinaryOp::BitwiseXOR:
- return { false, "^" };
+ return {false, "^"};
case BinaryOp::Min:
- return { true, is_float ? "fmin" : "min" };
+ return {true, is_float ? "fmin" : "min"};
case BinaryOp::Max:
- return { true, is_float ? "fmax" : "max" };
+ return {true, is_float ? "fmax" : "max"};
default:
CKW_THROW_MSG("Unsupported binary operator/function!");
@@ -255,13 +256,13 @@ std::tuple<bool, std::string> cl_get_binary_op(BinaryOp op, DataType data_type)
std::tuple<bool, std::string> cl_get_ternary_op(TernaryOp op)
{
- switch(op)
+ switch (op)
{
case TernaryOp::Select:
- return { true, "select" };
+ return {true, "select"};
case TernaryOp::Clamp:
- return { true, "clamp" };
+ return {true, "clamp"};
default:
CKW_THROW_MSG("Unsupported ternary function!");
@@ -273,7 +274,7 @@ std::string cl_data_type_rounded_up_to_valid_vector_width(DataType dt, int32_t w
std::string data_type;
const int32_t w = cl_round_up_to_nearest_valid_vector_width(width);
data_type += cl_get_variable_datatype_as_string(dt, 1);
- if(w != 1)
+ if (w != 1)
{
data_type += std::to_string(w);
}
@@ -284,7 +285,7 @@ std::vector<int32_t> cl_decompose_vector_width(int32_t vector_width)
{
std::vector<int32_t> x;
- switch(vector_width)
+ switch (vector_width)
{
case 0:
break;