aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-07-31 17:13:34 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-08-22 08:42:23 +0000
commite1c3b466960d5e3fd5a54871287f5eb6102bfb8c (patch)
treeca7b46273f564cd96bbb6832fbcd743ce4642301 /compute_kernel_writer
parent47a396e3aae96f2dcad44f4e0d6cb6b87b368395 (diff)
downloadComputeLibrary-e1c3b466960d5e3fd5a54871287f5eb6102bfb8c.tar.gz
Add CKW writing methods for CL unary ops
* Add writing methods for: - Assignment. - Cast. - Unary expression. * Add corresponding tests. Partially resolves: COMPMID-6388. Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: Ia654173e2e1ee9cddb7819980251e0591934439f Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10155 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-by: Jakub Sujak <jakub.sujak@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer')
-rw-r--r--compute_kernel_writer/CMakeLists.txt1
-rw-r--r--compute_kernel_writer/include/ckw/Error.h7
-rw-r--r--compute_kernel_writer/include/ckw/KernelWriter.h62
-rw-r--r--compute_kernel_writer/include/ckw/types/ConvertPolicy.h41
-rw-r--r--compute_kernel_writer/include/ckw/types/Operators.h57
-rw-r--r--compute_kernel_writer/src/cl/CLHelpers.cpp43
-rw-r--r--compute_kernel_writer/src/cl/CLHelpers.h21
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp146
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.h45
-rw-r--r--compute_kernel_writer/src/cl/CLTile.cpp4
-rw-r--r--compute_kernel_writer/src/types/DataTypeHelpers.cpp35
-rw-r--r--compute_kernel_writer/src/types/DataTypeHelpers.h43
-rw-r--r--compute_kernel_writer/validation/Validation.cpp15
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterAssignTest.h101
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterCastTest.h104
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterCommentTest.h16
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterDeclareTileTest.h2
-rw-r--r--compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h103
-rw-r--r--compute_kernel_writer/validation/tests/CLTileTest.hpp18
19 files changed, 795 insertions, 69 deletions
diff --git a/compute_kernel_writer/CMakeLists.txt b/compute_kernel_writer/CMakeLists.txt
index 783dd5e78b..a539ef7186 100644
--- a/compute_kernel_writer/CMakeLists.txt
+++ b/compute_kernel_writer/CMakeLists.txt
@@ -118,6 +118,7 @@ target_compile_definitions(ckw PUBLIC
)
target_sources(ckw PRIVATE
+ src/types/DataTypeHelpers.cpp
src/Error.cpp
src/Helpers.cpp
src/Kernel.cpp
diff --git a/compute_kernel_writer/include/ckw/Error.h b/compute_kernel_writer/include/ckw/Error.h
index eaf3f10c05..7da9544b9e 100644
--- a/compute_kernel_writer/include/ckw/Error.h
+++ b/compute_kernel_writer/include/ckw/Error.h
@@ -113,6 +113,13 @@ inline void ignore_unused(T &&...)
*/
#define CKW_ASSERT(cond) CKW_ASSERT_MSG(cond, #cond)
+/** Conditional assertion: if the precondition is met but the condition is not met, throw an std::runtime_error if assertion is enabled.
+ *
+ * This is the implication `precond => cond`, written as `!(precond) || (cond)`.
+ * Because of the short-circuiting `||`, @p cond is only evaluated when @p precond is true.
+ * The check is forwarded to @ref CKW_ASSERT, so the failure message is the stringified combined expression.
+ *
+ * @param[in] precond The precondition that triggers the check.
+ * @param[in] cond The condition that is expected to be true if precondition is true.
+ */
+#define CKW_ASSERT_IF(precond, cond) CKW_ASSERT(!(precond) || (cond))
+
/** Throw an std::runtime_error with the specified message if assertion is enabled.
*
* @param[in] msg The error message when the condition is not met.
diff --git a/compute_kernel_writer/include/ckw/KernelWriter.h b/compute_kernel_writer/include/ckw/KernelWriter.h
index f77798e2ab..7eb6d2894a 100644
--- a/compute_kernel_writer/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/include/ckw/KernelWriter.h
@@ -27,7 +27,10 @@
#include "ckw/TensorOperand.h"
#include "ckw/TileOperand.h"
+#include "ckw/types/ConvertPolicy.h"
+#include "ckw/types/Operators.h"
+#include <functional>
#include <memory>
#include <string>
@@ -76,6 +79,33 @@ public:
virtual ~KernelWriter();
// =============================================================================================
+ // Data processing
+ // =============================================================================================
+
+ /** Write assignment statement: `<dst> = <src>;`.
+ *
+ * NOTE(review): the OpenCL implementation added in this change asserts that both tiles have
+ * the same data type, and that the source width/height either matches the destination or is 1
+ * (broadcast in that dimension) - confirm whether other backends must follow the same contract.
+ *
+ * @param[in] dst The destination tile.
+ * @param[in] src The source tile.
+ */
+ virtual void op_assign(const TileOperand &dst, const TileOperand &src) = 0;
+
+ /** Write the cast statement: `<dst> = convert_<dst.type><policy>(<src>);`.
+ *
+ * NOTE(review): the OpenCL implementation added in this change asserts that the source and
+ * destination data types are different, and that @ref ConvertPolicy::Saturate is not requested
+ * for a floating-point destination - confirm whether other backends must follow the same contract.
+ *
+ * @param[in] dst The destination tile.
+ * @param[in] src The source tile.
+ * @param[in] policy The policy governing the behavior of the cast.
+ */
+ virtual void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) = 0;
+
+ /** Write the unary expression statement: `<dst> = <op> <src>;`.
+ *
+ * @ref UnaryOp covers both operators and functions. In the OpenCL implementation added in this
+ * change, operators are written as a prefix (`<dst> = <op><src>;`) while functions are written
+ * as a call (`<dst> = <op>(<src>);`).
+ *
+ * @param[in] dst The destination tile.
+ * @param[in] src The source tile.
+ * @param[in] op The unary operator.
+ */
+ virtual void op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op) = 0;
+
+ // =============================================================================================
// Misc
// =============================================================================================
@@ -87,7 +117,16 @@ public:
*
* @param[in] text The comment to be written.
*/
- virtual void comment(const std::string &text) = 0;
+ virtual void op_comment(const std::string &text) = 0;
+
+ /** Write the given raw code to the kernel source code.
+ *
+ * This is intended for cases where the user needs to explicitly add code
+ * that is not (yet) supported by the kernel writer utility calls.
+ *
+ * @param[in] raw_code The raw code to write, as a string.
+ */
+ virtual void op_write_raw_code(const std::string &raw_code) = 0;
// =============================================================================================
// Code generation
@@ -121,15 +160,6 @@ public:
*/
virtual TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) = 0;
- /** Write the given raw code to kernel source code
- * It's used to address the cases where the user needs to
- * explicitly add a code where it's not (yet) supported by
- * the kernel writer utility calls.
- *
- * @param[in] raw_code raw code to write as string
- */
- virtual void op_write_raw_code(const std::string &raw_code) = 0;
-
/** Load the data from the tensor memory to the tile using the sampling information.
*
* @param[in] tile_op The tile to be loaded.
@@ -140,7 +170,8 @@ public:
* @param[in] z z-coordinate
* @param[in] batch batch offset
*/
- virtual void op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
+ virtual void op_load(
+ const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0;
/** Load the data from the tensor memory to the tile in a dilated way using the sampling information.
@@ -150,7 +181,8 @@ public:
* @param[in] dilation_x Dilation while reading in x-dimension
* @param[in] dilation_y Dilation while reading in y-dimension
*/
- virtual void op_load_dilated(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
+ virtual void op_load_dilated(
+ const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y) = 0;
@@ -158,14 +190,16 @@ public:
*
* Similar to @ref KernelWriter::op_load()
*/
- virtual void op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
+ virtual void op_store(
+ const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0;
/** Store the data to the tensor memory from the tile in a dilated way using the sampling information.
*
* Similar to @ref KernelWriter::op_load_dilated()
*/
- virtual void op_store_dilated(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
+ virtual void op_store_dilated(
+ const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y) = 0;
diff --git a/compute_kernel_writer/include/ckw/types/ConvertPolicy.h b/compute_kernel_writer/include/ckw/types/ConvertPolicy.h
new file mode 100644
index 0000000000..43a37ff118
--- /dev/null
+++ b/compute_kernel_writer/include/ckw/types/ConvertPolicy.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_INCLUDE_CKW_TYPES_CONVERTPOLICY_H
+#define CKW_INCLUDE_CKW_TYPES_CONVERTPOLICY_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+/** Policy governing the behavior of a cast written by the kernel writer. */
+enum class ConvertPolicy : int32_t
+{
+ None = 0, // No policy specified: the OpenCL writer emits a plain `convert_<type>(...)`.
+ Saturate = 1, // Saturated: the OpenCL writer appends `_sat` to the conversion function name.
+};
+
+} // namespace ckw
+
+#endif // CKW_INCLUDE_CKW_TYPES_CONVERTPOLICY_H
diff --git a/compute_kernel_writer/include/ckw/types/Operators.h b/compute_kernel_writer/include/ckw/types/Operators.h
new file mode 100644
index 0000000000..ec2df08c46
--- /dev/null
+++ b/compute_kernel_writer/include/ckw/types/Operators.h
@@ -0,0 +1,57 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_INCLUDE_CKW_TYPES_OPERATORS_H
+#define CKW_INCLUDE_CKW_TYPES_OPERATORS_H
+
+#include <cstdint>
+
+namespace ckw
+{
+
+/** Unary operators and functions.
+ *
+ * The first group (0x000x) holds prefix operators; the second group (0x001x) holds
+ * entries that the OpenCL writer emits as function calls.
+ */
+enum class UnaryOp : int32_t
+{
+ LogicalNot = 0x0000, // Prefix operator: !
+ BitwiseNot = 0x0001, // Prefix operator: ~
+
+ Exp = 0x0010, // Function: exp
+ Tanh = 0x0011, // Function: tanh
+ Sqrt = 0x0012, // Function: sqrt
+ Erf = 0x0013, // Function: erf
+ Fabs = 0x0014, // Function: fabs
+ Log = 0x0015, // Function: log
+ Round = 0x0016, // Function: round
+};
+
+/** Assignment operators.
+ *
+ * These are compound assignments that update the destination in place.
+ */
+enum class AssignmentOp : int32_t
+{
+ Increment = 0x0000, // Compound addition: +=
+ Decrement = 0x0001, // Compound subtraction: -=
+};
+
+} // namespace ckw
+
+#endif // CKW_INCLUDE_CKW_TYPES_OPERATORS_H
diff --git a/compute_kernel_writer/src/cl/CLHelpers.cpp b/compute_kernel_writer/src/cl/CLHelpers.cpp
index 08108e383f..f62e1c28e6 100644
--- a/compute_kernel_writer/src/cl/CLHelpers.cpp
+++ b/compute_kernel_writer/src/cl/CLHelpers.cpp
@@ -21,10 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
#include "src/cl/CLHelpers.h"
+
#include "ckw/Error.h"
#include "ckw/types/DataType.h"
#include "ckw/types/TensorStorageType.h"
+#include "src/types/DataTypeHelpers.h"
namespace ckw
{
@@ -142,10 +145,46 @@ std::string cl_get_variable_storagetype_as_string(TensorStorageType storage)
return res;
}
+// Map a unary operation to its OpenCL spelling.
+// The first tuple element tells whether the spelling is a function name (true)
+// or a prefix operator (false); the second element is the spelling itself.
+std::tuple<bool, std::string> cl_get_unary_op(UnaryOp op)
+{
+    using OpInfo = std::tuple<bool, std::string>;
+
+    constexpr bool as_operator = false;
+    constexpr bool as_function = true;
+
+    switch(op)
+    {
+        // Prefix operators.
+        case UnaryOp::LogicalNot:
+            return OpInfo{ as_operator, "!" };
+        case UnaryOp::BitwiseNot:
+            return OpInfo{ as_operator, "~" };
+
+        // Built-in functions.
+        case UnaryOp::Exp:
+            return OpInfo{ as_function, "exp" };
+        case UnaryOp::Tanh:
+            return OpInfo{ as_function, "tanh" };
+        case UnaryOp::Sqrt:
+            return OpInfo{ as_function, "sqrt" };
+        case UnaryOp::Erf:
+            return OpInfo{ as_function, "erf" };
+        case UnaryOp::Fabs:
+            return OpInfo{ as_function, "fabs" };
+        case UnaryOp::Log:
+            return OpInfo{ as_function, "log" };
+        case UnaryOp::Round:
+            return OpInfo{ as_function, "round" };
+
+        default:
+            CKW_THROW_MSG("Unsupported unary operation!");
+    }
+}
+
std::string cl_data_type_rounded_up_to_valid_vector_width(DataType dt, int32_t width)
{
- std::string data_type;
- const int32_t w = cl_round_up_to_nearest_valid_vector_width(width);
+ std::string data_type;
+ const int32_t w = cl_round_up_to_nearest_valid_vector_width(width);
data_type += cl_get_variable_datatype_as_string(dt, 1);
if(w != 1)
{
diff --git a/compute_kernel_writer/src/cl/CLHelpers.h b/compute_kernel_writer/src/cl/CLHelpers.h
index 669424088e..3c1a7724e2 100644
--- a/compute_kernel_writer/src/cl/CLHelpers.h
+++ b/compute_kernel_writer/src/cl/CLHelpers.h
@@ -24,8 +24,11 @@
#ifndef CKW_SRC_CL_CLHELPERS_H
#define CKW_SRC_CL_CLHELPERS_H
+#include "ckw/types/Operators.h"
+
#include <cstdint>
#include <string>
+#include <tuple>
#include <vector>
/** OpenCL specific helper functions */
@@ -52,6 +55,24 @@ bool cl_validate_vector_length(int32_t len);
*/
std::string cl_get_variable_datatype_as_string(DataType dt, int32_t len);
+/** Return the assignment operator in OpenCL language.
+ *
+ * NOTE(review): the CLHelpers.cpp hunks of this change do not add a definition for this
+ * function - confirm that one exists, otherwise any caller will fail to link.
+ *
+ * @param[in] op The assignment operator.
+ *
+ * @return The operator in OpenCL language as a string.
+ */
+std::string cl_get_assignment_op_as_string(AssignmentOp op);
+
+/** Return the information about the unary operation.
+ *
+ * The result is a tuple containing, in this order:
+ * - is_func: true if it's a function and false if it's a unary operator in OpenCL language.
+ * - str: the function name or the operator in OpenCL language.
+ *
+ * An unsupported value of @p op is reported through CKW_THROW_MSG.
+ *
+ * @param[in] op The unary operator.
+ *
+ * @return The (is_func, str) tuple described above.
+ */
+std::tuple<bool, std::string> cl_get_unary_op(UnaryOp op);
+
/** Helper function to return the OpenCL vector size that accommodate the the desired width
*
* @param[in] width The desired width
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index b4df5c5f50..33d16da926 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -23,6 +23,7 @@
*/
#include "src/cl/CLKernelWriter.h"
+
#include "ckw/Error.h"
#include "ckw/Kernel.h"
#include "ckw/TensorSampler.h"
@@ -37,6 +38,9 @@
#include "src/cl/helpers/CLMemoryOpImage2dHelper.h"
#include "src/cl/helpers/ICLMemoryOpHelper.h"
+#include "src/types/DataTypeHelpers.h"
+
+#include <algorithm>
#include <cstdint>
namespace ckw
@@ -106,7 +110,95 @@ std::unique_ptr<Kernel> CLKernelWriter::emit_kernel(const std::string &name)
return std::make_unique<Kernel>(TargetLanguage::OpenCL, arguments, code);
}
-void CLKernelWriter::comment(const std::string &text)
+// Emit one `<dst> = <src>;` statement for every row of the destination tile.
+void CLKernelWriter::op_assign(const TileOperand &dst, const TileOperand &src)
+{
+    const auto &lhs = to_cl_tile(dst);
+    const auto &rhs = to_cl_tile(src);
+
+    const auto &lhs_info = lhs.info();
+    const auto &rhs_info = rhs.info();
+
+    const auto lhs_width  = lhs_info.width();
+    const auto lhs_height = lhs_info.height();
+    const auto rhs_width  = rhs_info.width();
+
+    // A single-element source row is widened to the destination vector type with an explicit cast.
+    const auto        lhs_vec_type = cl_get_variable_datatype_as_string(lhs_info.data_type(), lhs_width);
+    const bool        widen_x      = (lhs_width != 1) && (rhs_width == 1);
+    const std::string widen_cast   = widen_x ? "(" + lhs_vec_type + ")" : "";
+
+    CKW_ASSERT_MSG(rhs_info.data_type() == lhs_info.data_type(), "Source and destination type must match.");
+    CKW_ASSERT_MSG(rhs_info.height() == lhs_height || rhs_info.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+    CKW_ASSERT_MSG(rhs_width == lhs_width || rhs_width == 1, "Tile width must match or source is broadcasting in x dimension.");
+
+    // Broadcasting on y dimension is automatic (see CLTile::vector).
+    for(int32_t row = 0; row < lhs_height; ++row)
+    {
+        const auto &lhs_row = lhs.vector(row);
+        const auto &rhs_row = rhs.vector(row);
+
+        append_code(lhs_row.str, " = ", widen_cast, rhs_row.str, ";\n");
+    }
+}
+
+// Emit one `<dst> = convert_<type><sat>(<src>);` statement for every row of the destination tile.
+void CLKernelWriter::op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy)
+{
+    const auto &out_tile = to_cl_tile(dst);
+    const auto &in_tile  = to_cl_tile(src);
+
+    const auto &out_info = out_tile.info();
+    const auto &in_info  = in_tile.info();
+
+    const auto out_width  = out_info.width();
+    const auto out_height = out_info.height();
+    const auto in_width   = in_info.width();
+    const auto dst_type   = out_info.data_type();
+
+    // The conversion is written at the source width; the result is widened afterwards if needed.
+    const auto convert_fn_type = cl_get_variable_datatype_as_string(dst_type, in_width);
+    const auto out_vec_type    = cl_get_variable_datatype_as_string(dst_type, out_width);
+
+    // The saturating variant is selected with the `_sat` suffix and is rejected for floating-point destinations.
+    const std::string sat_suffix = (policy == ConvertPolicy::Saturate) ? "_sat" : "";
+    CKW_ASSERT_IF(policy == ConvertPolicy::Saturate, !is_data_type_float(dst_type));
+
+    const bool        widen_x    = (out_width != 1) && (in_width == 1);
+    const std::string widen_cast = widen_x ? "(" + out_vec_type + ")" : "";
+
+    CKW_ASSERT_MSG(in_info.data_type() != out_info.data_type(), "Source and destination type must be different.");
+    CKW_ASSERT_MSG(in_info.height() == out_height || in_info.height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+    CKW_ASSERT_MSG(in_width == out_width || in_width == 1, "Tile width must match or source is broadcasting in x dimension.");
+
+    // Broadcasting on y dimension is automatic (see CLTile::vector).
+    for(int32_t row = 0; row < out_height; ++row)
+    {
+        const auto &out_row = out_tile.vector(row);
+        const auto &in_row  = in_tile.vector(row);
+
+        append_code(out_row.str, " = ", widen_cast, "convert_", convert_fn_type, sat_suffix, "(", in_row.str, ");\n");
+    }
+}
+
+// Emit one unary statement for every row of the destination tile.
+//
+// Operators are written as a prefix (`<dst> = <op><src>;`) and functions as a call
+// (`<dst> = <op>(<src>);`). When the source is a single-element row and the destination
+// is a vector, the result is widened with an explicit cast to the destination vector type.
+//
+// dst: The destination tile.
+// src: The source tile; must have the same data type as dst, and a width/height that either
+//      matches dst or is 1 (broadcast).
+// op:  The unary operator or function; unsupported values are reported by cl_get_unary_op.
+void CLKernelWriter::op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op)
+{
+    const auto &dst_tile = to_cl_tile(dst);
+    const auto &src_tile = to_cl_tile(src);
+
+    const auto dst_w = dst_tile.info().width();
+    const auto dst_h = dst_tile.info().height();
+    const auto src_w = src_tile.info().width();
+
+    const auto data_type_str   = cl_get_variable_datatype_as_string(dst_tile.info().data_type(), dst_w);
+    const auto broadcast_src_x = dst_w != 1 && src_w == 1;
+
+    const std::string src_prefix = broadcast_src_x ? "(" + data_type_str + ")" : "";
+
+    const auto  op_info    = cl_get_unary_op(op);
+    const auto  op_is_func = std::get<0>(op_info);
+    const auto &op_name    = std::get<1>(op_info);
+    const auto  op_prefix  = op_is_func ? op_name + "(" : op_name;
+    const auto  op_suffix  = op_is_func ? ")" : "";
+
+    CKW_ASSERT_MSG(src_tile.info().data_type() == dst_tile.info().data_type(), "Source and destination type must match.");
+    // Only logical/bitwise not are restricted to non-floating-point tiles. The function-style
+    // operations (exp, tanh, sqrt, ...) must remain usable with floating-point tiles, so the
+    // check is applied only when one of the two "not" operators is requested.
+    CKW_ASSERT_MSG(!(op == UnaryOp::LogicalNot || op == UnaryOp::BitwiseNot) || !is_data_type_float(src_tile.info().data_type()), "Logical and bitwise not only work with integer.");
+    CKW_ASSERT_MSG(src_tile.info().height() == dst_h || src_tile.info().height() == 1, "Tile height must match or source is broadcasting in y dimension.");
+    CKW_ASSERT_MSG(src_w == dst_w || src_w == 1, "Tile width must match or source is broadcasting in x dimension.");
+
+    // Broadcasting on y dimension is automatic (see CLTile::vector).
+    for(int32_t y = 0; y < dst_h; ++y)
+    {
+        append_code(dst_tile.vector(y).str, " = ", src_prefix, op_prefix, src_tile.vector(y).str, op_suffix, ";\n");
+    }
+}
+
+void CLKernelWriter::op_comment(const std::string &text)
{
#ifdef COMPUTE_KERNEL_WRITER_DEBUG_ENABLED
@@ -147,13 +239,24 @@ TileOperand CLKernelWriter::declare_tile(const std::string &name, const TileInfo
const int32_t width = tile_info.width();
const DataType data_type = tile_info.data_type();
+ CKW_ASSERT_MSG(
+ std::find_if(
+ _tiles.begin(), _tiles.end(),
+ [=](const std::unique_ptr<CLTile> &e)
+ {
+ return e->name() == fullname;
+ })
+ == _tiles.end(),
+ "Tile name must be unique.");
+
+ auto tile = std::make_unique<CLTile>(fullname, tile_info);
+
for(int32_t row = 0; row < height; ++row)
{
const std::string cl_type = cl_get_variable_datatype_as_string(data_type, width);
- append_code(cl_type, " ", fullname, std::to_string(row), ";\n");
+ append_code(cl_type, " ", tile->vector(row).str, ";\n");
}
- auto tile = std::make_unique<CLTile>(name, tile_info);
const auto operand = create_tile_operand(*tile);
_tiles.insert(std::move(tile));
@@ -169,10 +272,12 @@ void CLKernelWriter::op_write_raw_code(const std::string &raw_code)
const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand)
{
const auto &tile = get_tile(operand);
+
#ifdef COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
// Check if the tile is a CLTile created by this kernel writer.
{
bool found = false;
+
for(const auto &t : _tiles)
{
if(&tile == t.get())
@@ -181,11 +286,13 @@ const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand)
break;
}
}
+
if(!found)
{
for(const auto &t : _tensors)
{
const auto components = t->components();
+
for(const auto component : components)
{
if(&tile == &component->tile())
@@ -194,16 +301,23 @@ const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand)
break;
}
}
+
+ if(found)
+ {
+ break;
+ }
}
}
+
CKW_ASSERT_MSG(found, "The tile is not found!");
}
#endif // COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
+
return static_cast<const CLTile &>(tile);
}
void CLKernelWriter::op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
{
const CLTile dilation_x("1", DataType::Int32);
const CLTile dilation_y("1", DataType::Int32);
@@ -212,8 +326,8 @@ void CLKernelWriter::op_load(const TileOperand &tile_op, const TensorOperand &te
}
void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileOperand &dilation_x, const TileOperand &dilation_y)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
+ const TileOperand &dilation_x, const TileOperand &dilation_y)
{
const auto &dil_x_tile = to_cl_tile(dilation_x);
const auto &dil_y_tile = to_cl_tile(dilation_y);
@@ -222,7 +336,7 @@ void CLKernelWriter::op_load_dilated(const TileOperand &tile_op, const TensorOpe
}
void CLKernelWriter::op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch)
{
const CLTile dilation_x("1", DataType::Int32);
const CLTile dilation_y("1", DataType::Int32);
@@ -231,8 +345,8 @@ void CLKernelWriter::op_store(const TensorOperand &tensor_op, const TileOperand
}
void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const TileOperand &dilation_x, const TileOperand &dilation_y)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
+ const TileOperand &dilation_x, const TileOperand &dilation_y)
{
const auto &dil_x_tile = to_cl_tile(dilation_x);
const auto &dil_y_tile = to_cl_tile(dilation_y);
@@ -241,11 +355,11 @@ void CLKernelWriter::op_store_dilated(const TensorOperand &tensor_op, const Tile
}
void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
- const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
- const CLTile &dilation_x, const CLTile &dilation_y)
+ const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
+ const CLTile &dilation_x, const CLTile &dilation_y)
{
CKW_UNUSED(dilation_x);
- CKW_ASSERT(dilation_x.scalar(0,0).str == "1"); // Dilation in x dimension is not implemented yet
+ CKW_ASSERT(dilation_x.scalar(0, 0).str == "1"); // Dilation in x dimension is not implemented yet
ITensor &tensor = get_tensor(tensor_op);
@@ -263,10 +377,10 @@ void CLKernelWriter::op_load_store(MemoryOperation op, const TileOperand &tile_o
CKW_THROW_MSG("Unsupported tensor storage");
}
- const auto &tile = to_cl_tile(tile_op);
- const auto &x_tile = to_cl_tile(x);
- const auto &y_tile = to_cl_tile(y);
- const auto &z_tile = to_cl_tile(z);
+ const auto &tile = to_cl_tile(tile_op);
+ const auto &x_tile = to_cl_tile(x);
+ const auto &y_tile = to_cl_tile(y);
+ const auto &z_tile = to_cl_tile(z);
const auto &batch_tile = to_cl_tile(batch);
helper->initialize(&tile, &x_tile, &z_tile, &batch_tile);
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h
index a40698d7bb..ea455a7fdd 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.h
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.h
@@ -57,13 +57,21 @@ public:
~CLKernelWriter();
// =============================================================================================
+ // Data processing
+ // =============================================================================================
+
+ void op_assign(const TileOperand &dst, const TileOperand &src) override;
+
+ void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) override;
+
+ void op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op) override;
+
+ // =============================================================================================
// Misc
// =============================================================================================
- /** Similar to @ref KernelWriter::comment() */
- void comment(const std::string &text) override;
+ void op_comment(const std::string &text) override;
- /** Similar to @ref KernelWriter::op_write_raw_code() */
void op_write_raw_code(const std::string &raw_code) override;
// =============================================================================================
@@ -92,14 +100,16 @@ public:
*
* Similar to @ref KernelWriter::op_load()
*/
- void op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
+ void op_load(
+ const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
/** Load the data from the tensor memory to the tile in a dilated way using the sampling information.
*
* Similar to @ref KernelWriter::op_load_dilated()
*/
- void op_load_dilated(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
+ void op_load_dilated(
+ const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y) override;
@@ -107,18 +117,26 @@ public:
*
* Similar to @ref KernelWriter::op_store()
*/
- void op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
+ void op_store(
+ const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) override;
/** Store the data to the tensor memory from the tile in a dilated way using the sampling information.
*
* Similar to @ref KernelWriter::op_store_dilated()
*/
- void op_store_dilated(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
+ void op_store_dilated(
+ const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const TileOperand &dilation_x, const TileOperand &dilation_y) override;
protected:
+ /** Return @ref CLTile object from the @ref TileOperand object.
+ *
+ * This function performs appropriate check before doing type casting.
+ */
+ const CLTile &to_cl_tile(const TileOperand &operand);
+
/** Append the specified code to the kernel body source code. */
template <typename T, typename... TArgs>
void append_code(T &&code, TArgs &&...args)
@@ -137,20 +155,15 @@ protected:
/** Get the current kernel body source code. */
const std::string &body_source_code() const;
-// For helper functions
+ // For helper functions
private:
- /** Return @ref CLTile object from the @ref TileOperand object.
- *
- * This function performs appropriate check before doing type casting.
- */
- const CLTile &to_cl_tile(const TileOperand &operand);
-
/** Helper function to consolidate all load/store logic in this class */
- void op_load_store(MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
+ void op_load_store(
+ MemoryOperation op, const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler,
const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch,
const CLTile &dilation_x, const CLTile &dilation_y);
-// For attributes
+ // For attributes
private:
/** This string contains the kernel body source code, not the full CL source code.
* The full source code will only be generated when the user calls @ref KernelWriter::emit_kernel.
diff --git a/compute_kernel_writer/src/cl/CLTile.cpp b/compute_kernel_writer/src/cl/CLTile.cpp
index 013ac4c276..556db0f47b 100644
--- a/compute_kernel_writer/src/cl/CLTile.cpp
+++ b/compute_kernel_writer/src/cl/CLTile.cpp
@@ -210,7 +210,7 @@ std::string CLTile::create_var_name(int32_t row) const
// If a scalar variable, we do not append the row index
if(_info.height() > 1)
{
- var_name += "_";
+ var_name += "__";
var_name += std::to_string(row);
}
@@ -229,4 +229,4 @@ void CLTile::validate_tile_info(const TileInfo &info) const
CKW_ASSERT_MSG(info.data_type() != DataType::Unknown, "DataType::Unknown is not supported");
}
-} // namespace ckw \ No newline at end of file
+} // namespace ckw
diff --git a/compute_kernel_writer/src/types/DataTypeHelpers.cpp b/compute_kernel_writer/src/types/DataTypeHelpers.cpp
new file mode 100644
index 0000000000..7f0c33fb72
--- /dev/null
+++ b/compute_kernel_writer/src/types/DataTypeHelpers.cpp
@@ -0,0 +1,35 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#include "src/types/DataTypeHelpers.h"
+
+namespace ckw
+{
+
+bool is_data_type_float(DataType data_type)
+{
+    return (data_type == DataType::Fp16) || (data_type == DataType::Fp32); // Half- or single-precision float.
+}
+
+} // namespace ckw
diff --git a/compute_kernel_writer/src/types/DataTypeHelpers.h b/compute_kernel_writer/src/types/DataTypeHelpers.h
new file mode 100644
index 0000000000..b6ec6ccd19
--- /dev/null
+++ b/compute_kernel_writer/src/types/DataTypeHelpers.h
@@ -0,0 +1,43 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#ifndef CKW_SRC_TYPES_DATATYPEHELPERS_H
+#define CKW_SRC_TYPES_DATATYPEHELPERS_H
+
+#include "ckw/types/DataType.h"
+
+namespace ckw
+{
+
+/** Return a value indicating whether the data type is floating-point.
+ *
+ * @param[in] data_type The data type to check.
+ *
+ * @return true if the data type is floating-point (Fp32 or Fp16), false otherwise.
+ */
+bool is_data_type_float(DataType data_type);
+
+} // namespace ckw
+
+#endif // CKW_SRC_TYPES_DATATYPEHELPERS_H
diff --git a/compute_kernel_writer/validation/Validation.cpp b/compute_kernel_writer/validation/Validation.cpp
index 3755986cf4..c55c7c0c07 100644
--- a/compute_kernel_writer/validation/Validation.cpp
+++ b/compute_kernel_writer/validation/Validation.cpp
@@ -23,14 +23,17 @@
*/
#include "validation/tests/CLConstantTileTest.hpp"
+#include "validation/tests/CLKernelWriterAssignTest.h"
+#include "validation/tests/CLKernelWriterCastTest.h"
#include "validation/tests/CLKernelWriterCommentTest.h"
+#include "validation/tests/CLKernelWriterDeclareTensorTest.h"
#include "validation/tests/CLKernelWriterDeclareTileTest.h"
+#include "validation/tests/CLKernelWriterOpLoadStoreTest.h"
+#include "validation/tests/CLKernelWriterUnaryExpressionTest.h"
#include "validation/tests/CLTensorArgumentTest.h"
#include "validation/tests/CLTileTest.hpp"
#include "validation/tests/TensorBitMaskTest.h"
#include "validation/tests/UtilsTest.h"
-#include "validation/tests/CLKernelWriterDeclareTensorTest.h"
-#include "validation/tests/CLKernelWriterOpLoadStoreTest.h"
#include <memory>
#include <vector>
@@ -77,6 +80,9 @@ int32_t main()
const auto test23 = std::make_unique<CLTensorArgumentComponentsUsedPassByValueTrueDynamicDimTrueTest>();
const auto test24 = std::make_unique<CLKernelWriterDeclareTensorTest>();
const auto test25 = std::make_unique<CLKernelWriterOpLoadStoreTest>();
+ const auto test26 = std::make_unique<CLKernelWriterAssignTest>();
+ const auto test27 = std::make_unique<CLKernelWriterCastTest>();
+ const auto test28 = std::make_unique<CLKernelWriterUnaryExpressionTest>();
tests.push_back(test3.get());
tests.push_back(test4.get());
@@ -102,7 +108,10 @@ int32_t main()
tests.push_back(test22.get());
tests.push_back(test23.get());
tests.push_back(test24.get());
- tests.push_back(test25.get());
+ CKW_UNUSED(test25); // CLKernelWriterOpLoadStoreTest test needs further changes.
+ tests.push_back(test26.get());
+ tests.push_back(test27.get());
+ tests.push_back(test28.get());
#endif /* COMPUTE_KERNEL_WRITER_OPENCL_ENABLED */
bool all_test_passed = true;
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterAssignTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterAssignTest.h
new file mode 100644
index 0000000000..f32f797a01
--- /dev/null
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterAssignTest.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_VALIDATION_TESTS_CLKERNELWRITERASSIGNTEST_H
+#define CKW_VALIDATION_TESTS_CLKERNELWRITERASSIGNTEST_H
+
+#include "ckw/TileInfo.h"
+#include "ckw/types/DataType.h"
+#include "src/cl/CLKernelWriter.h"
+#include "validation/tests/common/Common.h"
+#include "validation/tests/common/KernelWriterInterceptor.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace ckw
+{
+
+class CLKernelWriterAssignTest : public ITest
+{
+public:
+    /** Build the case table: matching shapes first, then each broadcast direction. */
+    CLKernelWriterAssignTest()
+        : _tests{
+              { 1, 1, 1, 1, DataType::Fp32, "G0__dst = G0__src;\n" }, // Scalar.
+              { 1, 3, 1, 3, DataType::Fp16, "G0__dst = G0__src;\n" }, // Whole vector.
+              { 2, 4, 2, 4, DataType::Int8, "G0__dst__0 = G0__src__0;\nG0__dst__1 = G0__src__1;\n" }, // Whole tile.
+              { 2, 3, 1, 3, DataType::Uint8, "G0__dst__0 = G0__src;\nG0__dst__1 = G0__src;\n" }, // Y-dimension broadcast.
+              { 2, 4, 2, 1, DataType::Fp32, "G0__dst__0 = (float4)G0__src__0;\nG0__dst__1 = (float4)G0__src__1;\n" }, // X-dimension broadcast.
+              { 2, 3, 1, 1, DataType::Fp16, "G0__dst__0 = (half3)G0__src;\nG0__dst__1 = (half3)G0__src;\n" }, // X and y dimension broadcast.
+          }
+    {
+    }
+
+    /** Emit op_assign for every case and validate the code captured from the writer. */
+    bool run() override
+    {
+        bool    all_tests_passed = true;
+        int32_t test_no          = 0;
+
+        for(const auto &test_case : _tests)
+        {
+            KernelWriterInterceptor<CLKernelWriter> writer;
+            // TileInfo is built from (data type, height, width); dst is declared before src.
+            const TileInfo dst_info(test_case.data_type, test_case.dst_height, test_case.dst_width);
+            auto           dst = writer.declare_tile("dst", dst_info);
+            const TileInfo src_info(test_case.data_type, test_case.src_height, test_case.src_width);
+            auto           src = writer.declare_tile("src", src_info);
+            // Capture begins after the declarations; the expected strings hold only the assignment.
+            writer.start_capture_code();
+            writer.op_assign(dst, src);
+            VALIDATE_TEST(writer.check_added_code(test_case.expected_code), all_tests_passed, test_no++);
+        }
+        return all_tests_passed;
+    }
+
+    /** Name reported for this test. */
+    std::string name() override
+    {
+        return "CLKernelWriterAssignTest";
+    }
+
+private:
+    /** One case: destination shape, source shape, shared data type and expected code. */
+    struct TestInfo
+    {
+        int32_t     dst_height;
+        int32_t     dst_width;
+        int32_t     src_height;
+        int32_t     src_width;
+        DataType    data_type;
+        std::string expected_code;
+    };
+
+    std::vector<TestInfo> _tests{};
+};
+
+} // namespace ckw
+
+#endif // CKW_VALIDATION_TESTS_CLKERNELWRITERASSIGNTEST_H
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterCastTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterCastTest.h
new file mode 100644
index 0000000000..a185cce545
--- /dev/null
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterCastTest.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_VALIDATION_TESTS_CLKERNELWRITERCASTTEST_H
+#define CKW_VALIDATION_TESTS_CLKERNELWRITERCASTTEST_H
+
+#include "ckw/TileInfo.h"
+#include "ckw/types/ConvertPolicy.h"
+#include "ckw/types/DataType.h"
+#include "src/cl/CLKernelWriter.h"
+#include "validation/tests/common/Common.h"
+#include "validation/tests/common/KernelWriterInterceptor.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace ckw
+{
+
+class CLKernelWriterCastTest : public ITest
+{
+public:
+    /** Build the case table: matching shapes first, then each broadcast direction. */
+    CLKernelWriterCastTest()
+        : _tests{
+              { 1, 1, DataType::Fp16, 1, 1, DataType::Fp32, ConvertPolicy::None, "G0__dst = convert_half(G0__src);\n" }, // Scalar.
+              { 1, 3, DataType::Int32, 1, 3, DataType::Fp16, ConvertPolicy::Saturate, "G0__dst = convert_int3_sat(G0__src);\n" }, // Whole vector.
+              { 2, 4, DataType::Uint16, 2, 4, DataType::Int8, ConvertPolicy::Saturate, "G0__dst__0 = convert_ushort4_sat(G0__src__0);\nG0__dst__1 = convert_ushort4_sat(G0__src__1);\n" }, // Whole tile.
+              { 2, 3, DataType::Int8, 1, 3, DataType::Uint8, ConvertPolicy::None, "G0__dst__0 = convert_char3(G0__src);\nG0__dst__1 = convert_char3(G0__src);\n" }, // Y-dimension broadcast.
+              { 2, 4, DataType::Fp16, 2, 1, DataType::Fp32, ConvertPolicy::None, "G0__dst__0 = (half4)convert_half(G0__src__0);\nG0__dst__1 = (half4)convert_half(G0__src__1);\n" }, // X-dimension broadcast.
+              { 2, 3, DataType::Fp32, 1, 1, DataType::Fp16, ConvertPolicy::None, "G0__dst__0 = (float3)convert_float(G0__src);\nG0__dst__1 = (float3)convert_float(G0__src);\n" }, // X and y dimension broadcast.
+          }
+    {
+    }
+
+    /** Emit op_cast for every case and validate the code captured from the writer. */
+    bool run() override
+    {
+        bool    all_tests_passed = true;
+        int32_t test_no          = 0;
+
+        for(const auto &test_case : _tests)
+        {
+            KernelWriterInterceptor<CLKernelWriter> writer;
+            // TileInfo is built from (data type, height, width); dst is declared before src.
+            const TileInfo dst_info(test_case.dst_data_type, test_case.dst_height, test_case.dst_width);
+            auto           dst = writer.declare_tile("dst", dst_info);
+            const TileInfo src_info(test_case.src_data_type, test_case.src_height, test_case.src_width);
+            auto           src = writer.declare_tile("src", src_info);
+            // Capture begins after the declarations; the expected strings hold only the conversion.
+            writer.start_capture_code();
+            writer.op_cast(dst, src, test_case.policy);
+            VALIDATE_TEST(writer.check_added_code(test_case.expected_code), all_tests_passed, test_no++);
+        }
+        return all_tests_passed;
+    }
+
+    /** Name reported for this test. */
+    std::string name() override
+    {
+        return "CLKernelWriterCastTest";
+    }
+
+private:
+    /** One case: destination shape and type, source shape and type, policy and expected code. */
+    struct TestInfo
+    {
+        int32_t       dst_height;
+        int32_t       dst_width;
+        DataType      dst_data_type;
+        int32_t       src_height;
+        int32_t       src_width;
+        DataType      src_data_type;
+        ConvertPolicy policy;
+        std::string   expected_code;
+    };
+
+    std::vector<TestInfo> _tests{};
+};
+
+} // namespace ckw
+
+#endif // CKW_VALIDATION_TESTS_CLKERNELWRITERCASTTEST_H
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterCommentTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterCommentTest.h
index ff09ea8073..b36c3905ec 100644
--- a/compute_kernel_writer/validation/tests/CLKernelWriterCommentTest.h
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterCommentTest.h
@@ -22,8 +22,8 @@
* SOFTWARE.
*/
-#ifndef CKW_VALIDATION_TESTS_CLKERNELTEST_H
-#define CKW_VALIDATION_TESTS_CLKERNELTEST_H
+#ifndef CKW_VALIDATION_TESTS_CLKERNELWRITERCOMMENTTEST_H
+#define CKW_VALIDATION_TESTS_CLKERNELWRITERCOMMENTTEST_H
#include "src/cl/CLKernelWriter.h"
#include "validation/tests/common/Common.h"
@@ -45,14 +45,18 @@ public:
KernelWriterInterceptor<CLKernelWriter> writer;
- writer.comment("previous code");
+ writer.op_comment("previous code");
writer.start_capture_code();
- writer.comment("code under test 0");
- writer.comment("code under test 1");
+ writer.op_comment("code under test 0");
+ writer.op_comment("code under test 1");
+#ifdef COMPUTE_KERNEL_WRITER_DEBUG_ENABLED
constexpr auto expected_code = "// code under test 0\n// code under test 1\n";
+#else // COMPUTE_KERNEL_WRITER_DEBUG_ENABLED
+ constexpr auto expected_code = "";
+#endif // COMPUTE_KERNEL_WRITER_DEBUG_ENABLED
VALIDATE_TEST(writer.check_added_code(expected_code), all_tests_passed, 0);
@@ -67,4 +71,4 @@ public:
} // namespace ckw
-#endif // CKW_VALIDATION_TESTS_CLKERNELTEST_H
+#endif // CKW_VALIDATION_TESTS_CLKERNELWRITERCOMMENTTEST_H
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterDeclareTileTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterDeclareTileTest.h
index 5e00084aaa..4f728bc1bf 100644
--- a/compute_kernel_writer/validation/tests/CLKernelWriterDeclareTileTest.h
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterDeclareTileTest.h
@@ -73,7 +73,7 @@ public:
std::string expected_code = "";
for(int32_t row = 0; row < height; ++row)
{
- expected_code += prefix + std::to_string(row) + ";\n";
+ expected_code += prefix + ((height > 1) ? std::string("__") + std::to_string(row) : "") + ";\n";
}
TileInfo tile_info(data_type, height, width);
diff --git a/compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h b/compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h
new file mode 100644
index 0000000000..65440a0a99
--- /dev/null
+++ b/compute_kernel_writer/validation/tests/CLKernelWriterUnaryExpressionTest.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_VALIDATION_TESTS_CLKERNELWRITERUNARYEXPRESSIONTEST_H
+#define CKW_VALIDATION_TESTS_CLKERNELWRITERUNARYEXPRESSIONTEST_H
+
+#include "ckw/TileInfo.h"
+#include "ckw/types/DataType.h"
+#include "ckw/types/Operators.h"
+#include "src/cl/CLKernelWriter.h"
+#include "validation/tests/common/Common.h"
+#include "validation/tests/common/KernelWriterInterceptor.h"
+
+#include <cstdint>
+#include <vector>
+
+namespace ckw
+{
+
+class CLKernelWriterUnaryExpressionTest : public ITest
+{
+public:
+    /** Build the case table: matching shapes first, then each broadcast direction. */
+    CLKernelWriterUnaryExpressionTest()
+        : _tests{
+              { 1, 1, 1, 1, DataType::Uint32, UnaryOp::BitwiseNot, "G0__dst = ~G0__src;\n" }, // Scalar.
+              { 1, 3, 1, 3, DataType::Int16, UnaryOp::LogicalNot, "G0__dst = !G0__src;\n" }, // Whole vector.
+              { 2, 4, 2, 4, DataType::Int8, UnaryOp::Exp, "G0__dst__0 = exp(G0__src__0);\nG0__dst__1 = exp(G0__src__1);\n" }, // Whole tile.
+              { 2, 3, 1, 3, DataType::Uint8, UnaryOp::Log, "G0__dst__0 = log(G0__src);\nG0__dst__1 = log(G0__src);\n" }, // Y-dimension broadcast.
+              { 2, 4, 2, 1, DataType::Uint16, UnaryOp::Sqrt, "G0__dst__0 = (ushort4)sqrt(G0__src__0);\nG0__dst__1 = (ushort4)sqrt(G0__src__1);\n" }, // X-dimension broadcast.
+              { 2, 3, 1, 1, DataType::Int32, UnaryOp::Round, "G0__dst__0 = (int3)round(G0__src);\nG0__dst__1 = (int3)round(G0__src);\n" }, // X and y dimension broadcast.
+          }
+    {
+    }
+
+    /** Emit op_unary for every case and validate the code captured from the writer. */
+    bool run() override
+    {
+        bool    all_tests_passed = true;
+        int32_t test_no          = 0;
+
+        for(const auto &test_case : _tests)
+        {
+            KernelWriterInterceptor<CLKernelWriter> writer;
+            // TileInfo is built from (data type, height, width); dst is declared before src.
+            const TileInfo dst_info(test_case.data_type, test_case.dst_height, test_case.dst_width);
+            auto           dst = writer.declare_tile("dst", dst_info);
+            const TileInfo src_info(test_case.data_type, test_case.src_height, test_case.src_width);
+            auto           src = writer.declare_tile("src", src_info);
+            // Capture begins after the declarations; the expected strings hold only the expression.
+            writer.start_capture_code();
+            writer.op_unary(dst, src, test_case.op);
+            VALIDATE_TEST(writer.check_added_code(test_case.expected_code), all_tests_passed, test_no++);
+        }
+        return all_tests_passed;
+    }
+
+    /** Name reported for this test. */
+    std::string name() override
+    {
+        return "CLKernelWriterUnaryExpressionTest";
+    }
+
+private:
+    /** One case: destination shape, source shape, shared data type, operator and expected code. */
+    struct TestInfo
+    {
+        int32_t     dst_height;
+        int32_t     dst_width;
+        int32_t     src_height;
+        int32_t     src_width;
+        DataType    data_type;
+        UnaryOp     op;
+        std::string expected_code;
+    };
+
+    std::vector<TestInfo> _tests{};
+};
+
+} // namespace ckw
+
+#endif // CKW_VALIDATION_TESTS_CLKERNELWRITERUNARYEXPRESSIONTEST_H
diff --git a/compute_kernel_writer/validation/tests/CLTileTest.hpp b/compute_kernel_writer/validation/tests/CLTileTest.hpp
index ecfe811267..a95a11ace7 100644
--- a/compute_kernel_writer/validation/tests/CLTileTest.hpp
+++ b/compute_kernel_writer/validation/tests/CLTileTest.hpp
@@ -22,8 +22,8 @@
* SOFTWARE.
*/
-#ifndef COMPUTE_KERNEL_WRITER_TESTS_CLTILETEST_HPP
-#define COMPUTE_KERNEL_WRITER_TESTS_CLTILETEST_HPP
+#ifndef CKW_VALIDATION_TESTS_CLTILETEST_HPP
+#define CKW_VALIDATION_TESTS_CLTILETEST_HPP
#include "common/Common.h"
#include "src/Helpers.h"
@@ -63,7 +63,7 @@ public:
for(int32_t y = 0; y < height; ++y)
{
- const std::string expected_var_name = tile_name + "_" + std::to_string(y);
+ const std::string expected_var_name = tile_name + "__" + std::to_string(y);
const std::string actual_var_name = vars[y].str;
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
}
@@ -172,7 +172,7 @@ public:
const std::string actual_var_name = var.str;
std::string expected_var_name = tile_name;
- expected_var_name += "_" + std::to_string(y_coord);
+ expected_var_name += "__" + std::to_string(y_coord);
expected_var_name += ".s" + dec_to_hex_as_string(x_coord);
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
@@ -238,7 +238,7 @@ public:
const std::string actual_var_name = var.str;
std::string expected_var_name = tile_name;
- expected_var_name += "_" + std::to_string(y_coord);
+ expected_var_name += "__" + std::to_string(y_coord);
if(width != 1)
{
expected_var_name += ".s" + dec_to_hex_as_string(x_coord_clamped);
@@ -310,7 +310,7 @@ public:
std::string expected_var_name = tile_name;
if(height != 1)
{
- expected_var_name += "_" + std::to_string(y_coord_clamped);
+ expected_var_name += "__" + std::to_string(y_coord_clamped);
}
if(width != 1)
@@ -367,7 +367,7 @@ public:
std::string expected_var_name = tile_name;
if(height != 1)
{
- expected_var_name += "_" + std::to_string(row);
+ expected_var_name += "__" + std::to_string(row);
}
VALIDATE_TEST(actual_var_name.compare(expected_var_name) == 0, all_tests_passed, test_idx++);
@@ -423,7 +423,7 @@ public:
std::string expected_var_name = tile_name;
if(height != 1)
{
- expected_var_name += "_" + std::to_string(row);
+ expected_var_name += "__" + std::to_string(row);
}
if(width != 1)
@@ -464,4 +464,4 @@ private:
};
} // namespace ckw
-#endif /* COMPUTE_KERNEL_WRITER_TESTS_CLTILETEST_HPP */
+#endif // CKW_VALIDATION_TESTS_CLTILETEST_HPP