aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-08-22 11:11:23 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-08-24 09:48:58 +0000
commit34b6c3a08c3fd3f99cf675921a319b8678a98273 (patch)
treeb52068faf874063b79f0e4ddd7d587e785bb65bf /compute_kernel_writer
parent3a9ecdfdc76abd7f9acdab42a1f7e4c0188d6f48 (diff)
downloadComputeLibrary-34b6c3a08c3fd3f99cf675921a319b8678a98273.tar.gz
Add CKW binary and ternary statements
Resolves: COMPMID-6388 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: Ia0cd1486f368af54053066f489cac83b9de01789 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10182 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer')
-rw-r--r--compute_kernel_writer/include/ckw/KernelWriter.h23
-rw-r--r--compute_kernel_writer/include/ckw/types/Operators.h43
-rw-r--r--compute_kernel_writer/src/cl/CLHelpers.cpp71
-rw-r--r--compute_kernel_writer/src/cl/CLHelpers.h27
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp119
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.h6
-rw-r--r--compute_kernel_writer/validation/Validation.cpp6
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterBinaryOpTest.h133
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterTernaryOpTest.h111
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h4
10 files changed, 538 insertions, 5 deletions
diff --git a/compute_kernel_writer/include/ckw/KernelWriter.h b/compute_kernel_writer/include/ckw/KernelWriter.h
index 7eb6d2894a..d59867fa6f 100644
--- a/compute_kernel_writer/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/include/ckw/KernelWriter.h
@@ -100,10 +100,29 @@ public:
/** Write the unary expression statement: `<dst> = <op> <src>;`.
*
* @param[in] dst The destination tile.
- * @param[in] src The source tile.
* @param[in] op The unary operator.
+ * @param[in] src The source tile.
+ */
+ virtual void op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) = 0;
+
+ /** Write the binary expression statement: `<dst> = <op>(<first>, <second>);`.
+ *
+ * @param[in] dst The destination tile.
+ * @param[in] op The binary operator.
+ * @param[in] first The first source tile.
+ * @param[in] second The second source tile.
+ */
+ virtual void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) = 0;
+
+ /** Write ternary expression statement: `<dst> = <op>(<first>, <second>, <third>);`.
+ *
+ * @param[in] dst The destination tile.
+ * @param[in] op The ternary operator.
+ * @param[in] first The first source tile.
+ * @param[in] second The second source tile.
+ * @param[in] third The third source tile.
*/
- virtual void op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op) = 0;
+ virtual void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) = 0;
// =============================================================================================
// Misc
diff --git a/compute_kernel_writer/include/ckw/types/Operators.h b/compute_kernel_writer/include/ckw/types/Operators.h
index ec2df08c46..1e5f9bd542 100644
--- a/compute_kernel_writer/include/ckw/types/Operators.h
+++ b/compute_kernel_writer/include/ckw/types/Operators.h
@@ -52,6 +52,49 @@ enum class AssignmentOp : int32_t
Decrement = 0x0001, // -=
};
+/** Binary operators. */
+enum class BinaryOp : int32_t
+{
+ // Elementwise
+ Add = 0x0000, // +
+ Sub = 0x0001, // -
+ Mul = 0x0002, // *
+ Div = 0x0003, // /
+ Mod = 0x0004, // %
+
+ // Relational
+ Equal = 0x1000, // ==
+ Less = 0x1001, // <
+ LessEqual = 0x1002, // <=
+ Greater = 0x1003, // >
+ GreaterEqual = 0x1004, // >=
+
+ // Algebra
+ MatMul_Nt_Nt = 0x2000, // X
+ MatMul_Nt_T = 0x2001, // X
+ MatMul_T_Nt = 0x2002, // X
+ MatMul_T_T = 0x2003, // X
+ Dot = 0x2004, // .
+
+ // Logical
+ LogicalAnd = 0x3000, // &&
+ LogicalOr = 0x3001, // ||
+
+ // Bitwise
+ BitwiseXOR = 0x4000, // ^
+
+ // Functions
+ Min = 0x8000,
+ Max = 0x8001,
+};
+
+/** Ternary operators. */
+enum class TernaryOp : int32_t
+{
+ Select = 0x0000,
+ Clamp = 0x0001,
+};
+
} // namespace ckw
#endif // CKW_INCLUDE_CKW_TYPES_OPERATORS_H
diff --git a/compute_kernel_writer/src/cl/CLHelpers.cpp b/compute_kernel_writer/src/cl/CLHelpers.cpp
index f62e1c28e6..e12e5e1b13 100644
--- a/compute_kernel_writer/src/cl/CLHelpers.cpp
+++ b/compute_kernel_writer/src/cl/CLHelpers.cpp
@@ -181,6 +181,77 @@ std::tuple<bool, std::string> cl_get_unary_op(UnaryOp op)
}
}
+std::tuple<bool, std::string> cl_get_binary_op(BinaryOp op, DataType data_type)
+{
+ const auto is_float = is_data_type_float(data_type);
+
+ switch(op)
+ {
+ case BinaryOp::Add:
+ return { false, "+" };
+
+ case BinaryOp::Sub:
+ return { false, "-" };
+
+ case BinaryOp::Mul:
+ return { false, "*" };
+
+ case BinaryOp::Div:
+ return { false, "/" };
+
+ case BinaryOp::Mod:
+ return { false, "%" };
+
+ case BinaryOp::Equal:
+ return { false, "==" };
+
+ case BinaryOp::Less:
+ return { false, "<" };
+
+ case BinaryOp::LessEqual:
+ return { false, "<=" };
+
+ case BinaryOp::Greater:
+ return { false, ">" };
+
+ case BinaryOp::GreaterEqual:
+ return { false, ">=" };
+
+ case BinaryOp::LogicalAnd:
+ return { false, "&&" };
+
+ case BinaryOp::LogicalOr:
+ return { false, "||" };
+
+ case BinaryOp::BitwiseXOR:
+ return { false, "^" };
+
+ case BinaryOp::Min:
+ return { true, is_float ? "fmin" : "min" };
+
+ case BinaryOp::Max:
+ return { true, is_float ? "fmax" : "max" };
+
+ default:
+ CKW_THROW_MSG("Unsupported binary operator/function!");
+ }
+}
+
+std::tuple<bool, std::string> cl_get_ternary_op(TernaryOp op)
+{
+ switch(op)
+ {
+ case TernaryOp::Select:
+ return { true, "select" };
+
+ case TernaryOp::Clamp:
+ return { true, "clamp" };
+
+ default:
+ CKW_THROW_MSG("Unsupported ternary function!");
+ }
+}
+
std::string cl_data_type_rounded_up_to_valid_vector_width(DataType dt, int32_t width)
{
std::string data_type;
diff --git a/compute_kernel_writer/src/cl/CLHelpers.h b/compute_kernel_writer/src/cl/CLHelpers.h
index 3c1a7724e2..370ffc700c 100644
--- a/compute_kernel_writer/src/cl/CLHelpers.h
+++ b/compute_kernel_writer/src/cl/CLHelpers.h
@@ -70,9 +70,36 @@ std::string cl_get_assignment_op_as_string(AssignmentOp op);
* - str: the function name or the operator in OpenCL language.
*
* @param[in] op The unary operator.
+ *
+ * @return The information about the unary operation.
*/
std::tuple<bool, std::string> cl_get_unary_op(UnaryOp op);
+/** Return the information about the binary operation.
+ *
+ * The result contains:
+ * - is_func: true if it's a function and false if it's an binary operator in OpenCL language.
+ * - str: the function name or the operator in OpenCL language.
+ *
+ * @param[in] op The binary operator.
+ * @param[in] data_type The input data type.
+ *
+ * @return The information about the binary operation.
+ */
+std::tuple<bool, std::string> cl_get_binary_op(BinaryOp op, DataType data_type);
+
+/** Return the information about the ternary operation.
+ *
+ * The result contains:
+ * - is_func: true if it's a function and false if it's a ternary operator in OpenCL language.
+ * - str: the function name or the operator in OpenCL language.
+ *
+ * @param[in] op The ternary operator.
+ *
+ * @return The information about the ternary operation.
+ */
+std::tuple<bool, std::string> cl_get_ternary_op(TernaryOp op);
+
/** Helper function to return the OpenCL vector size that accommodate the the desired width
*
* @param[in] width The desired width
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 33d16da926..312162f498 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -166,7 +166,7 @@ void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, Con
}
}
-void CLKernelWriter::op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op)
+void CLKernelWriter::op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src)
{
const auto &dst_tile = to_cl_tile(dst);
const auto &src_tile = to_cl_tile(src);
@@ -198,6 +198,123 @@ void CLKernelWriter::op_unary(const TileOperand &dst, const TileOperand &src, Un
}
}
+void CLKernelWriter::op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second)
+{
+ const auto &dst_tile = to_cl_tile(dst);
+ const auto &lhs_tile = to_cl_tile(first);
+ const auto &rhs_tile = to_cl_tile(second);
+
+ const auto dst_w = dst_tile.info().width();
+ const auto dst_h = dst_tile.info().height();
+ const auto lhs_w = lhs_tile.info().width();
+ const auto rhs_w = rhs_tile.info().width();
+
+ const auto data_type = lhs_tile.info().data_type();
+
+ CKW_ASSERT_MSG(lhs_tile.info().data_type() == rhs_tile.info().data_type(), "LHS and RHS type must match.");
+
+ CKW_ASSERT_MSG(lhs_tile.info().height() == dst_h || lhs_tile.info().height() == 1, "LHS tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(rhs_tile.info().height() == dst_h || rhs_tile.info().height() == 1, "RHS tile height must match or source is broadcasting in y dimension.");
+
+ CKW_ASSERT_MSG(lhs_w == dst_w || lhs_w == 1, "LHS tile width must match destination or LHS is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(rhs_w == dst_w || rhs_w == 1, "RHS tile width must match destination or RHS is broadcasting in x dimension.");
+
+ if(op == BinaryOp::MatMul_Nt_T)
+ {
+ CKW_ASSERT(is_data_type_float(data_type));
+
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ for(int32_t x = 0; x < dst_w; ++x)
+ {
+ for(int32_t k = 0; k < lhs_w; ++k)
+ {
+ append_code(
+ dst_tile.scalar(x, y).str, " = fma(",
+ lhs_tile.scalar(k, y).str, ", ",
+ rhs_tile.scalar(k, x).str, ", ",
+ dst_tile.scalar(x, y).str, ");\n");
+ }
+ }
+ }
+ }
+ else
+ {
+ const auto op_info = cl_get_binary_op(op, data_type);
+ const auto op_is_func = std::get<0>(op_info);
+ const auto &op_name = std::get<1>(op_info);
+
+ const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w);
+
+ const auto broadcast_lhs_x = dst_w != 1 && lhs_w == 1;
+ const auto broadcast_rhs_x = dst_w != 1 && rhs_w == 1;
+
+ const std::string lhs_prefix = broadcast_lhs_x ? "(" + data_type_str + ")" : "";
+ const std::string rhs_prefix = broadcast_rhs_x ? "(" + data_type_str + ")" : "";
+
+ const std::string op_prefix = op_is_func ? " = " + op_name + "(" : " = ";
+ const std::string op_separator = op_is_func ? ", " : " " + op_name + " ";
+ const std::string op_suffix = op_is_func ? ");\n" : ";\n";
+
+ // Broadcasting on y dimension is automatic (see CLTile::vector).
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ append_code(dst_tile.vector(y).str, op_prefix, lhs_prefix, lhs_tile.vector(y).str, op_separator, rhs_prefix, rhs_tile.vector(y).str, op_suffix);
+ }
+ }
+}
+
+void CLKernelWriter::op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third)
+{
+ const auto &dst_tile = to_cl_tile(dst);
+ const auto &first_tile = to_cl_tile(first);
+ const auto &second_tile = to_cl_tile(second);
+ const auto &third_tile = to_cl_tile(third);
+
+ const auto dst_w = dst_tile.info().width();
+ const auto dst_h = dst_tile.info().height();
+ const auto first_w = first_tile.info().width();
+ const auto second_w = second_tile.info().width();
+ const auto third_w = third_tile.info().width();
+
+ const auto data_type = dst_tile.info().data_type();
+ const auto data_type_str = cl_get_variable_datatype_as_string(data_type, dst_w);
+
+ const auto op_info = cl_get_ternary_op(op);
+ const auto op_is_func = std::get<0>(op_info);
+ const auto &op_name = std::get<1>(op_info);
+
+ const auto broadcast_first_x = dst_w != 1 && first_w == 1;
+ const auto broadcast_second_x = dst_w != 1 && second_w == 1;
+ const auto broadcast_third_x = dst_w != 1 && third_w == 1;
+
+ const std::string first_prefix = broadcast_first_x ? "(" + data_type_str + ")" : "";
+ const std::string second_prefix = broadcast_second_x ? "(" + data_type_str + ")" : "";
+ const std::string third_prefix = broadcast_third_x ? "(" + data_type_str + ")" : "";
+
+ CKW_ASSERT_MSG(op_is_func, "The only supported ternary operator is function.");
+ CKW_ASSERT_MSG(second_tile.info().data_type() == dst_tile.info().data_type(), "2nd source and destination type must match.");
+ CKW_ASSERT_MSG(third_tile.info().data_type() == dst_tile.info().data_type(), "3rd source and destination type must match.");
+
+ CKW_ASSERT_MSG(first_tile.info().height() == dst_h || first_tile.info().height() == 1, "1st tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(second_tile.info().height() == dst_h || second_tile.info().height() == 1, "2nd tile height must match or source is broadcasting in y dimension.");
+ CKW_ASSERT_MSG(third_tile.info().height() == dst_h || third_tile.info().height() == 1, "3rd tile height must match or source is broadcasting in y dimension.");
+
+ CKW_ASSERT_MSG(first_w == dst_w || first_w == 1, "1st tile width must match or source is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(second_w == dst_w || second_w == 1, "2nd tile width must match or source is broadcasting in x dimension.");
+ CKW_ASSERT_MSG(third_w == dst_w || third_w == 1, "3rd tile width must match or source is broadcasting in x dimension.");
+
+ // Broadcasting on y dimension is automatic (see CLTile::vector).
+ for(int32_t y = 0; y < dst_h; ++y)
+ {
+ append_code(
+ dst_tile.vector(y).str, " = ", op_name, "(",
+ first_prefix, first_tile.vector(y).str, ", ",
+ second_prefix, second_tile.vector(y).str, ", ",
+ third_prefix, third_tile.vector(y).str, ");\n");
+ }
+}
+
void CLKernelWriter::op_comment(const std::string &text)
{
#ifdef COMPUTE_KERNEL_WRITER_DEBUG_ENABLED
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h
index ea455a7fdd..2a6b79c691 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.h
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.h
@@ -64,7 +64,11 @@ public:
void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) override;
- void op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op) override;
+ void op_unary(const TileOperand &dst, UnaryOp op, const TileOperand &src) override;
+
+ void op_binary(const TileOperand &dst, BinaryOp op, const TileOperand &first, const TileOperand &second) override;
+
+ void op_ternary(const TileOperand &dst, TernaryOp op, const TileOperand &first, const TileOperand &second, const TileOperand &third) override;
// =============================================================================================
// Misc
diff --git a/compute_kernel_writer/validation/Validation.cpp b/compute_kernel_writer/validation/Validation.cpp
index c55c7c0c07..20957d90a5 100644
--- a/compute_kernel_writer/validation/Validation.cpp
+++ b/compute_kernel_writer/validation/Validation.cpp
@@ -24,11 +24,13 @@
#include "validation/tests/CLConstantTileTest.hpp"
#include "validation/tests/CLKernelWriterAssignTest.h"
+#include "validation/tests/CLKernelWriterBinaryOpTest.h"
#include "validation/tests/CLKernelWriterCastTest.h"
#include "validation/tests/CLKernelWriterCommentTest.h"
#include "validation/tests/CLKernelWriterDeclareTensorTest.h"
#include "validation/tests/CLKernelWriterDeclareTileTest.h"
#include "validation/tests/CLKernelWriterOpLoadStoreTest.h"
+#include "validation/tests/CLKernelWriterTernaryOpTest.h"
#include "validation/tests/CLKernelWriterUnaryExpressionTest.h"
#include "validation/tests/CLTensorArgumentTest.h"
#include "validation/tests/CLTileTest.hpp"
@@ -83,6 +85,8 @@ int32_t main()
const auto test26 = std::make_unique<CLKernelWriterAssignTest>();
const auto test27 = std::make_unique<CLKernelWriterCastTest>();
const auto test28 = std::make_unique<CLKernelWriterUnaryExpressionTest>();
+ const auto test29 = std::make_unique<CLKernelWriterBinaryOpTest>();
+ const auto test30 = std::make_unique<CLKernelWriterTernaryOpTest>();
tests.push_back(test3.get());
tests.push_back(test4.get());
@@ -112,6 +116,8 @@ int32_t main()
tests.push_back(test26.get());
tests.push_back(test27.get());
tests.push_back(test28.get());
+ tests.push_back(test29.get());
+ tests.push_back(test30.get());
#endif /* COMPUTE_KERNEL_WRITER_OPENCL_ENABLED */
bool all_test_passed = true;
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterBinaryOpTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterBinaryOpTest.h
new file mode 100644
index 0000000000..bfa6724008
--- /dev/null
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterBinaryOpTest.h
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_VALIDATION_TESTS_CLKERNELWRITERBINARYOPTEST_H
+#define CKW_VALIDATION_TESTS_CLKERNELWRITERBINARYOPTEST_H
+
+#include "ckw/TileInfo.h"
+#include "ckw/types/DataType.h"
+#include "src/cl/CLKernelWriter.h"
+#include "validation/tests/common/Common.h"
+#include "validation/tests/common/KernelWriterInterceptor.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace ckw
+{
+
+class CLKernelWriterBinaryOpTest : public ITest
+{
+public:
+ CLKernelWriterBinaryOpTest()
+ {
+ // dst_height, dst_width, dst_data_type, lhs_height, lhs_width, rhs_height, rhs_width, src_data_type, op, expected_code
+ _tests.push_back({ 1, 1, DataType::Fp32, 1, 1, 1, 1, DataType::Fp32, BinaryOp::Add, "G0__dst = G0__lhs + G0__rhs;\n" }); // Scalar.
+
+ _tests.push_back({ 1, 3, DataType::Bool, 1, 3, 1, 3, DataType::Fp16, BinaryOp::Equal, "G0__dst = G0__lhs == G0__rhs;\n" }); // Whole vector.
+
+ _tests.push_back({ 2, 4, DataType::Int8, 2, 4, 2, 4, DataType::Int8, BinaryOp::Min, "G0__dst__0 = min(G0__lhs__0, G0__rhs__0);\nG0__dst__1 = min(G0__lhs__1, G0__rhs__1);\n" }); // Whole tile.
+
+ _tests.push_back({ 2, 3, DataType::Uint8, 1, 3, 2, 3, DataType::Uint8, BinaryOp::BitwiseXOR, "G0__dst__0 = G0__lhs ^ G0__rhs__0;\nG0__dst__1 = G0__lhs ^ G0__rhs__1;\n" }); // LHS y-dimension broadcast.
+
+ _tests.push_back({ 2, 3, DataType::Bool, 2, 3, 1, 3, DataType::Fp32, BinaryOp::Less, "G0__dst__0 = G0__lhs__0 < G0__rhs;\nG0__dst__1 = G0__lhs__1 < G0__rhs;\n" }); // RHS y-dimension broadcast.
+
+ _tests.push_back({ 2, 3, DataType::Fp16, 1, 3, 1, 3, DataType::Fp16, BinaryOp::Max, "G0__dst__0 = fmax(G0__lhs, G0__rhs);\nG0__dst__1 = fmax(G0__lhs, G0__rhs);\n" }); // LHS and RHS y-dimension broadcast.
+
+ _tests.push_back({ 2, 4, DataType::Fp32, 2, 1, 2, 4, DataType::Fp32, BinaryOp::Div, "G0__dst__0 = (float4)G0__lhs__0 / G0__rhs__0;\nG0__dst__1 = (float4)G0__lhs__1 / G0__rhs__1;\n" }); // LHS x-dimension broadcast.
+
+ _tests.push_back({ 2, 4, DataType::Fp16, 2, 4, 2, 1, DataType::Fp16, BinaryOp::Mod, "G0__dst__0 = G0__lhs__0 % (half4)G0__rhs__0;\nG0__dst__1 = G0__lhs__1 % (half4)G0__rhs__1;\n" }); // RHS x-dimension broadcast.
+
+ _tests.push_back({ 2, 4, DataType::Bool, 2, 1, 2, 1, DataType::Fp32, BinaryOp::GreaterEqual, "G0__dst__0 = (float4)G0__lhs__0 >= (float4)G0__rhs__0;\nG0__dst__1 = (float4)G0__lhs__1 >= (float4)G0__rhs__1;\n" }); // LHS and RHS x-dimension broadcast.
+
+ _tests.push_back({ 2, 3, DataType::Fp32, 2, 3, 2, 3, DataType::Fp32, BinaryOp::MatMul_Nt_T,
+ "G0__dst__0.s0 = fma(G0__lhs__0.s0, G0__rhs__0.s0, G0__dst__0.s0);\n"
+ "G0__dst__0.s0 = fma(G0__lhs__1.s0, G0__rhs__1.s0, G0__dst__0.s0);\n"
+ "G0__dst__0.s0 = fma(G0__lhs__1.s0, G0__rhs__1.s0, G0__dst__0.s0);\n"
+ "G0__dst__1.s0 = fma(G0__lhs__0.s0, G0__rhs__0.s1, G0__dst__1.s0);\n"
+ "G0__dst__1.s0 = fma(G0__lhs__1.s0, G0__rhs__1.s1, G0__dst__1.s0);\n"
+ "G0__dst__1.s0 = fma(G0__lhs__1.s0, G0__rhs__1.s1, G0__dst__1.s0);\n"
+ "G0__dst__1.s0 = fma(G0__lhs__0.s0, G0__rhs__0.s2, G0__dst__1.s0);\n"
+ "G0__dst__1.s0 = fma(G0__lhs__1.s0, G0__rhs__1.s2, G0__dst__1.s0);\n"
+ "G0__dst__1.s0 = fma(G0__lhs__1.s0, G0__rhs__1.s2, G0__dst__1.s0);\n"
+ "G0__dst__0.s1 = fma(G0__lhs__0.s1, G0__rhs__0.s0, G0__dst__0.s1);\n"
+ "G0__dst__0.s1 = fma(G0__lhs__1.s1, G0__rhs__1.s0, G0__dst__0.s1);\n"
+ "G0__dst__0.s1 = fma(G0__lhs__1.s1, G0__rhs__1.s0, G0__dst__0.s1);\n"
+ "G0__dst__1.s1 = fma(G0__lhs__0.s1, G0__rhs__0.s1, G0__dst__1.s1);\n"
+ "G0__dst__1.s1 = fma(G0__lhs__1.s1, G0__rhs__1.s1, G0__dst__1.s1);\n"
+ "G0__dst__1.s1 = fma(G0__lhs__1.s1, G0__rhs__1.s1, G0__dst__1.s1);\n"
+ "G0__dst__1.s1 = fma(G0__lhs__0.s1, G0__rhs__0.s2, G0__dst__1.s1);\n"
+ "G0__dst__1.s1 = fma(G0__lhs__1.s1, G0__rhs__1.s2, G0__dst__1.s1);\n"
+ "G0__dst__1.s1 = fma(G0__lhs__1.s1, G0__rhs__1.s2, G0__dst__1.s1);\n" });
+ }
+
+ bool run() override
+ {
+ int32_t test_no = 0;
+ bool all_tests_passed = true;
+
+ for(const auto &test : _tests)
+ {
+ KernelWriterInterceptor<CLKernelWriter> writer;
+
+ auto dst = writer.declare_tile("dst", TileInfo(test.dst_data_type, test.dst_height, test.dst_width));
+ auto lhs = writer.declare_tile("lhs", TileInfo(test.src_data_type, test.lhs_height, test.lhs_width));
+ auto rhs = writer.declare_tile("rhs", TileInfo(test.src_data_type, test.rhs_height, test.rhs_width));
+
+ writer.start_capture_code();
+
+ writer.op_binary(dst, test.op, lhs, rhs);
+
+ VALIDATE_TEST(writer.check_added_code(test.expected_code), all_tests_passed, test_no++);
+ }
+
+ return all_tests_passed;
+ }
+
+ std::string name() override
+ {
+ return "CLKernelWriterBinaryOpTest";
+ }
+
+private:
+ struct TestInfo
+ {
+ int32_t dst_height;
+ int32_t dst_width;
+ DataType dst_data_type;
+ int32_t lhs_height;
+ int32_t lhs_width;
+ int32_t rhs_height;
+ int32_t rhs_width;
+ DataType src_data_type;
+ BinaryOp op;
+ std::string expected_code;
+ };
+
+ std::vector<TestInfo> _tests{};
+};
+
+} // namespace ckw
+
+#endif // CKW_VALIDATION_TESTS_CLKERNELWRITERBINARYOPTEST_H
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterTernaryOpTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterTernaryOpTest.h
new file mode 100644
index 0000000000..d25d3e2958
--- /dev/null
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterTernaryOpTest.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_VALIDATION_TESTS_CLKERNELWRITERTERNARYOPTEST_H
+#define CKW_VALIDATION_TESTS_CLKERNELWRITERTERNARYOPTEST_H
+
+#include "ckw/TileInfo.h"
+#include "ckw/types/DataType.h"
+#include "ckw/types/Operators.h"
+#include "src/cl/CLKernelWriter.h"
+#include "validation/tests/common/Common.h"
+#include "validation/tests/common/KernelWriterInterceptor.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace ckw
+{
+
+class CLKernelWriterTernaryOpTest : public ITest
+{
+public:
+ CLKernelWriterTernaryOpTest()
+ {
+ // dst_height, dst_width, op0_height, op0_width, op1_height, op1_width, op2_height, op2_width, data_type, op, expected_code
+
+ _tests.push_back({ 1, 1, 1, 1, 1, 1, 1, 1, DataType::Fp32, TernaryOp::Select, "G0__dst = select(G0__op0, G0__op1, G0__op2);\n" }); // Scalar.
+
+ _tests.push_back({ 1, 3, 1, 3, 1, 3, 1, 3, DataType::Fp16, TernaryOp::Clamp, "G0__dst = clamp(G0__op0, G0__op1, G0__op2);\n" }); // Whole vector.
+
+ _tests.push_back({ 2, 4, 2, 4, 2, 4, 2, 4, DataType::Int8, TernaryOp::Select, "G0__dst__0 = select(G0__op0__0, G0__op1__0, G0__op2__0);\nG0__dst__1 = select(G0__op0__1, G0__op1__1, G0__op2__1);\n" }); // Whole tile.
+
+ _tests.push_back({ 2, 3, 1, 3, 2, 3, 2, 3, DataType::Uint8, TernaryOp::Clamp, "G0__dst__0 = clamp(G0__op0, G0__op1__0, G0__op2__0);\nG0__dst__1 = clamp(G0__op0, G0__op1__1, G0__op2__1);\n" }); // 1st operand y-dimension broadcast.
+
+ _tests.push_back({ 2, 3, 2, 3, 2, 1, 2, 3, DataType::Fp32, TernaryOp::Select, "G0__dst__0 = select(G0__op0__0, (float3)G0__op1__0, G0__op2__0);\nG0__dst__1 = select(G0__op0__1, (float3)G0__op1__1, G0__op2__1);\n" }); // 2nd operand x-dimension broadcast.
+
+ _tests.push_back({ 2, 3, 1, 3, 2, 1, 1, 1, DataType::Fp16, TernaryOp::Clamp, "G0__dst__0 = clamp(G0__op0, (half3)G0__op1__0, (half3)G0__op2);\nG0__dst__1 = clamp(G0__op0, (half3)G0__op1__1, (half3)G0__op2);\n" }); // 1st operand y-, 2nd operand x-, 3rd operand x- and y-dimension broadcast.
+ }
+
+ bool run() override
+ {
+ int32_t test_no = 0;
+ bool all_tests_passed = true;
+
+ for(const auto &test : _tests)
+ {
+ KernelWriterInterceptor<CLKernelWriter> writer;
+
+ auto dst = writer.declare_tile("dst", TileInfo(test.data_type, test.dst_height, test.dst_width));
+ auto op0 = writer.declare_tile("op0", TileInfo(DataType::Bool, test.op0_height, test.op0_width));
+ auto op1 = writer.declare_tile("op1", TileInfo(test.data_type, test.op1_height, test.op1_width));
+ auto op2 = writer.declare_tile("op2", TileInfo(test.data_type, test.op2_height, test.op2_width));
+
+ writer.start_capture_code();
+
+ writer.op_ternary(dst, test.op, op0, op1, op2);
+
+ VALIDATE_TEST(writer.check_added_code(test.expected_code), all_tests_passed, test_no++);
+ }
+
+ return all_tests_passed;
+ }
+
+ std::string name() override
+ {
+ return "CLKernelWriterTernaryOpTest";
+ }
+
+private:
+ struct TestInfo
+ {
+ int32_t dst_height;
+ int32_t dst_width;
+ int32_t op0_height;
+ int32_t op0_width;
+ int32_t op1_height;
+ int32_t op1_width;
+ int32_t op2_height;
+ int32_t op2_width;
+ DataType data_type;
+ TernaryOp op;
+ std::string expected_code;
+ };
+
+ std::vector<TestInfo> _tests{};
+};
+
+} // namespace ckw
+
+#endif // CKW_VALIDATION_TESTS_CLKERNELWRITERTERNARYOPTEST_H
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h
index 65440a0a99..395a2fe817 100644
--- a/compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h
@@ -43,6 +43,8 @@ class CLKernelWriterUnaryExpressionTest : public ITest
public:
CLKernelWriterUnaryExpressionTest()
{
+ // dst_height, dst_width, src_height, src_width, data_type, op, expected_code
+
_tests.push_back({ 1, 1, 1, 1, DataType::Uint32, UnaryOp::BitwiseNot, "G0__dst = ~G0__src;\n" }); // Scalar.
_tests.push_back({ 1, 3, 1, 3, DataType::Int16, UnaryOp::LogicalNot, "G0__dst = !G0__src;\n" }); // Whole vector.
@@ -70,7 +72,7 @@ public:
writer.start_capture_code();
- writer.op_unary(dst, src, test.op);
+ writer.op_unary(dst, test.op, src);
VALIDATE_TEST(writer.check_added_code(test.expected_code), all_tests_passed, test_no++);
}