aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikolaj Jensen <nikolaj.jensen@arm.com>2023-06-27 14:13:24 +0100
committerNikolaj Jensen <nikolaj.jensen@arm.com>2023-07-26 10:18:54 +0000
commitfab6c210b37f1fa6b3e37a2583b18f8e4b5a4f12 (patch)
treebef025e8bee719fb6649eedc99150f5c066d1bc3
parent23882a9014eb3972bca958206866c8e0d0b829cc (diff)
downloadComputeLibrary-fab6c210b37f1fa6b3e37a2583b18f8e4b5a4f12.tar.gz
Design wrapper around CKW for easier writing
Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com> Change-Id: I114cdedcaf05c6abde046741837eeb73b813aa9d Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com> Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/532180 Tested-by: bsgcomp <bsgcomp@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: bsgcomp <bsgcomp@arm.com> Signed-off-by: Nikolaj Jensen <nikolaj.jensen@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9921 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--compute_kernel_writer/prototype/CMakeLists.txt3
-rw-r--r--compute_kernel_writer/prototype/examples/writer_helper.cpp118
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelWriter.h5
-rw-r--r--compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h1268
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/Functions.h1
-rw-r--r--compute_kernel_writer/prototype/include/ckw/types/Operators.h4
-rw-r--r--compute_kernel_writer/prototype/src/KernelWriter.cpp5
-rw-r--r--compute_kernel_writer/prototype/src/Prototype.h10
8 files changed, 1401 insertions, 13 deletions
diff --git a/compute_kernel_writer/prototype/CMakeLists.txt b/compute_kernel_writer/prototype/CMakeLists.txt
index 3d6a192378..13d1ae8fc4 100644
--- a/compute_kernel_writer/prototype/CMakeLists.txt
+++ b/compute_kernel_writer/prototype/CMakeLists.txt
@@ -73,3 +73,6 @@ target_link_libraries(ckw_prototype_examples_common PUBLIC ckw_prototype)
add_executable(ckw_prototype_examples_add_exp_store examples/add_exp_store.cpp)
target_link_libraries(ckw_prototype_examples_add_exp_store PUBLIC ckw_prototype_examples_common)
+
+add_executable(writer_helper examples/writer_helper.cpp)
+target_link_libraries(writer_helper PUBLIC ckw_prototype)
diff --git a/compute_kernel_writer/prototype/examples/writer_helper.cpp b/compute_kernel_writer/prototype/examples/writer_helper.cpp
new file mode 100644
index 0000000000..ccef92dcdf
--- /dev/null
+++ b/compute_kernel_writer/prototype/examples/writer_helper.cpp
@@ -0,0 +1,118 @@
+/*
+* Copyright (c) 2023 Arm Limited.
+*
+* SPDX-License-Identifier: MIT
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files (the "Software"), to
+* deal in the Software without restriction, including without limitation the
+* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+* sell copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*/
+
+#include "ckw/KernelWriter.h"
+#include "../include/ckw/KernelWriterHelper.h"
+#include "ckw/TensorTileSampler.h"
+
+#include <iostream>
+
+using namespace ckw;
+
+TensorTileSampler create_simple_sampler(KernelWriter& writer)
+{
+ TensorTileSampler sampler;
+
+ constexpr int32_t m0 = 1;
+ constexpr int32_t n0 = 1;
+
+ auto &gid_0 = writer.declare_tile("gid_0", DataType::Int32);
+ auto &gid_1 = writer.declare_tile("gid_1", DataType::Int32);
+ auto &gid_2 = writer.declare_tile("gid_2", DataType::Int32);
+
+ auto &const_0 = writer.declare_tile("0", 0);
+
+ writer.op_get_global_id(gid_0, 0);
+ writer.op_get_global_id(gid_1, 1);
+ writer.op_get_global_id(gid_2, 2);
+
+ sampler.x(gid_0);
+ sampler.y(gid_1);
+ sampler.z(gid_2);
+ sampler.b(const_0);
+
+ sampler.width(n0);
+ sampler.height(m0);
+
+ sampler.format(TensorSamplerFormat::C_WH_1);
+ sampler.address_mode_x(TensorSamplerAddressModeX::None);
+ sampler.address_mode_y(TensorSamplerAddressModeY::ClampToBorder);
+ sampler.address_mode_z(TensorSamplerAddressModeZ::Skip);
+
+ return sampler;
+}
+
+int main()
+{
+ Kernel kernel("test", GpuTargetLanguage::OpenCL);
+ KernelWriterHelper<KernelWriter> writer(kernel);
+
+ const TensorInfo src_info(DataType::Fp32, TensorShape({ 1, 1, 1, 1, 1 }), TensorDataLayout::Nhwc, 0);
+ const TensorInfo dst_info(DataType::Fp32, TensorShape({ 1, 1, 1, 1, 1 }), TensorDataLayout::Nhwc, 1);
+
+ auto &src_tensor = writer.declare_tensor_argument("src", src_info);
+ auto &dst_tensor = writer.declare_tensor_argument("dst", dst_info);
+
+ const auto sampler = create_simple_sampler(writer);
+
+ auto &src = writer.declare_tile("src_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width()));
+ auto &other = writer.declare_tile("other_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width()));
+ auto &dst = writer.declare_tile("dst_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width()));
+
+ writer.op_load(src, src_tensor, sampler);
+ writer.op_load(other, src_tensor, sampler);
+ writer.op_load(dst, dst_tensor, sampler);
+
+ auto test = dst ^ src ^ other;
+ auto other_test = logical_and(dst, src, other);
+ writer.op_assign(dst, logical_and(dst, src, other));
+ writer.op_assign(dst, test);
+ writer.op_assign(dst, other_test);
+ writer.op_assign(dst, operator^(operator^(dst, src), other));
+
+ writer.op_if(exp(src) == dst, [&]{
+ writer.op_binary_expression(dst, src, BinaryOp::Add, src);
+ }).op_else_if(exp(src) > dst, [&]{
+ writer.op_binary_expression(dst, src, BinaryOp::Add, src);
+ }).op_else([&] {
+ writer.op_assign(dst, src);
+ });
+
+ writer.op_assign(dst, src + src * src);
+ writer.op_assign(dst, src * max(src, dst) + src);
+ writer.op_assign(dst, src * select(src, dst, src) + src);
+
+ writer.op_assign(dst, src ^ dst);
+ writer.op_assign(dst, ~src);
+
+ writer.op_for_loop(dst < src, dst += src, [&]{
+ writer.op_assign(dst, src + dst);
+ });
+
+ writer.op_assign(dst += src);
+ writer.op_assign(dst += exp(src));
+
+ std::cout << "======== KERNEL ========" << std::endl;
+ std::cout << writer.generate_code() << std::endl;
+} \ No newline at end of file
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
index 146fdac53e..c116e62650 100644
--- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
+++ b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h
@@ -230,16 +230,17 @@ public:
*/
void op_else(const std::function<void()> &body);
- /** Write for-loops: `for(; <var> <cond_op> <cond_value>; <update_op> <update_value>) { body }`.
+ /** Write for-loops: `for(; <var> <cond_op> <cond_value>; <var> <update_op> <update_value>) { body }`.
*
* @param[in] var_name The name of the variable used in condition.
* @param[in] cond_op The relational binary operator used in condition.
* @param[in] cond_value_name The value which the variable is compared against.
+ * @param[in] update_var_name The name of the variable which is updated.
* @param[in] update_op The assignment operator used for updating the update value.
* @param[in, out] update_value The value which is updated at every iteration.
* @param[in] body The body of the for-loop.
*/
- void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body);
+ void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body);
/** Write the return statement: `return;`
*/
diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h b/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h
new file mode 100644
index 0000000000..a8be859680
--- /dev/null
+++ b/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h
@@ -0,0 +1,1268 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
+#define CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
+
+#include "ckw/KernelWriter.h"
+#include "ckw/TensorOperand.h"
+#include "ckw/TileOperand.h"
+
+#include <iostream>
+#include <type_traits>
+
+#include <iostream>
+
+/*
+ * By including this header file you will be able to supplement the default
+ * Compute Kernel Writer API with additional syntax to help ease the use of CKW.
+ *
+ * To use the KernelWriterHelper you need to wrap your KernelWriter type
+ * (or any class deriving from KernelWriter) and construct it as usual:
+ * KernelWriterHelper<KernelWriter> writer(kernel);
+ * The resulting writer object comprises the original KernelWriter
+ * functionality (drop-in replacement), but extends the syntax as follows.
+ *
+ * Common functions/operators have natural syntax:
+ * 1. Unary expressions:
+ * writer.op_assign(dst, !src); // Logical NOT
+ * writer.op_assign(dst, ~src); // Bitwise NOT
+ *
+ * 2. Binary expressions:
+ * writer.op_assign(dst, lhs + rhs); // Addition
+ * writer.op_assign(dst, lhs - rhs); // Subtraction
+ * writer.op_assign(dst, lhs * rhs); // Multiplication
+ * writer.op_assign(dst, lhs / rhs); // Division
+ * writer.op_assign(dst, lhs % rhs); // Modulo
+ * writer.op_assign(dst, lhs == rhs); // Equality
+ * writer.op_assign(dst, lhs < rhs); // Less-than
+ * writer.op_assign(dst, lhs <= rhs); // Less-than-or-equal
+ * writer.op_assign(dst, lhs > rhs); // Greater-than
+ * writer.op_assign(dst, lhs >= rhs); // Greater-than-or-equal
+ * writer.op_assign(dst, lhs ^ rhs); // Bitwise XOR
+ * writer.op_assign(dst, logical_and(lhs, rhs)); // Logical AND
+ * writer.op_assign(dst, logical_or(lhs, rhs)); // Logical OR
+ *
+ * 3. Unary elementwise functions:
+ * writer.op_assign(dst, exp(src)); // Exponent
+ * writer.op_assign(dst, tanh(src)); // Hyperbolic tangent
+ * writer.op_assign(dst, sqrt(src)); // Square root
+ * writer.op_assign(dst, erf(src)); // Error function
+ * writer.op_assign(dst, fabs(src)); // Absolute value of a floating-point number
+ * writer.op_assign(dst, log(src)); // Natural logarithm
+ * writer.op_assign(dst, round(src)); // Round
+ * writer.op_assign(dst, sizeOf(src)); // sizeof
+ *
+ * 4. Binary elementwise functions:
+ * writer.op_assign(dst, max(first, second)); // Max
+ * writer.op_assign(dst, min(first, second)); // Min
+ *
+ * 5. Ternary elementwise functions:
+ * writer.op_assign(dst, select(first, second, third)); // Select
+ *
+ * NOTE: All the above examples support nesting, so you could write
+ * something like: writer.op_assign(dst, src * (log(arg) + sqrt(fabs(arg))));
+ *
+ *
+ * 6. If-statements. The preceding syntax also allows easier writing of if-statements:
+ * writer.op_if(<cond>, <body>);
+ *
+ * For example:
+ * writer.op_if(exp(first_arg) == dst, [&]{
+ * //...
+ * }).op_else_if(exp(first_arg) > dst, [&]{
+ * //...
+ * }).op_else([&] {
+ * //...
+ * });
+ *
+ * 7. For-loops. A similar syntax exists for for-loops:
+ * writer.op_for_loop(<cond>, <updater>, <body>);
+ *
+ * For example:
+ * writer.op_for_loop(index < limit, index += step, [&]{
+ * //...
+ * });
+ *
+ * NOTE: There are limitations on the for-loop <cond> and <updater> parameters.
+ * Nesting is not allowed in either the <cond> (binary expression) or the
+ * <updater> (increment/decrement). For example, `(index + other) < limit` and
+ * `index < round(limit)` are invalid <cond> parameters. This is because the
+ * semantics of for-loops rely on the condition being evaluated at every iteration,
+ * but since temporary variables might be defined for nested expressions, those
+ * semantics cannot be guaranteed.
+ */
+
+namespace ckw
+{
+
+// ==================================================
+// Type traits
+// ==================================================
+
+/** Specifies if the type can be used as an operand for functions (e.g. max), operations (e.g. *), or assignments. */
+template <typename T>
+struct can_be_operand : ::std::false_type
+{
+};
+
+/** Specifies if the type can be assigned/written to. */
+template <typename T>
+struct can_be_assigned : ::std::false_type
+{
+};
+
+template <>
+struct can_be_operand<TileOperand &> : ::std::true_type
+{
+};
+
+template <>
+struct can_be_assigned<TileOperand &> : ::std::true_type
+{
+};
+
+// ==================================================
+// Assignment
+// ==================================================
+
+/** AST node for assignments.
+ *
+ * Stores the destination, the source and the assignment operator of an
+ * expression such as `lhs += rhs`.
+ *
+ * \p TRight must be an operand, and \p TLeft must be assignable. Naming this
+ * template with any other pair of types is a substitution failure, which keeps
+ * function templates returning this node out of overload resolution for
+ * unrelated types.
+ *
+ * @tparam TLeft  The type of the destination of the assignment.
+ * @tparam TRight The type of the source assigned to the destination.
+ */
+template <typename TLeft, typename TRight, typename = ::std::enable_if_t<can_be_operand<TRight>::value && can_be_assigned<TLeft>::value>>
+struct Assignment
+{
+    TLeft        lhs;    /**< Destination of the assignment. */
+    TRight       rhs;    /**< Source of the assignment. */
+    AssignmentOp opcode; /**< Assignment operator to apply. */
+};
+
+/** Represents the expression: `\p lhs += \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the assignment.
+ * @tparam TRight The type of the RHS of the assignment.
+ * @param[in] lhs The LHS of the assignment.
+ * @param[in] rhs The RHS of the assignment.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline Assignment<TLeft, TRight> operator+=(TLeft &&lhs, TRight &&rhs)
+{
+ return Assignment<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Increment };
+}
+
+/** Represents the expression: `\p lhs -= \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the assignment.
+ * @tparam TRight The type of the RHS of the assignment.
+ * @param[in] lhs The LHS of the assignment.
+ * @param[in] rhs The RHS of the assignment.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline Assignment<TLeft, TRight> operator-=(TLeft &&lhs, TRight &&rhs)
+{
+ return Assignment<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), AssignmentOp::Decrement };
+}
+
+// ==================================================
+// Unary expression
+// ==================================================
+
+/** AST node for unary expressions.
+ *
+ * Stores the argument and the unary operator of an expression such as `!src`.
+ *
+ * \p TSrc must be an operand. Naming this template with any other type is a
+ * substitution failure, which keeps function templates returning this node out
+ * of overload resolution for unrelated types.
+ *
+ * @tparam TSrc The type of the argument to the expression.
+ */
+template <typename TSrc, typename = ::std::enable_if_t<can_be_operand<TSrc>::value>>
+struct UnaryExpression
+{
+    TSrc    src;    /**< Argument of the expression. */
+    UnaryOp opcode; /**< Unary operator to apply. */
+};
+
+template <typename TLeft>
+struct can_be_operand<UnaryExpression<TLeft>> : ::std::true_type
+{
+};
+
+/** Represents the expression: `!\p src`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+inline UnaryExpression<TSrc> operator!(TSrc &&src)
+{
+ return UnaryExpression<TSrc>{ std::forward<TSrc>(src), UnaryOp::LogicalNot };
+}
+
+/** Represents the expression: `~\p src`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+inline UnaryExpression<TSrc> operator~(TSrc &&src)
+{
+ return UnaryExpression<TSrc>{ std::forward<TSrc>(src), UnaryOp::BitwiseNot };
+}
+
+// ==================================================
+// Binary expressions
+// ==================================================
+
+/** AST node for binary expressions.
+ *
+ * Note that both \p TLeft and \p TRight must be operands.
+ *
+ * @tparam TLeft The type of the left argument of the expression.
+ * @tparam TRight The type of the right argument of the expression.
+ */
+template <typename TLeft, typename TRight, typename = ::std::enable_if_t<can_be_operand<TLeft>::value && can_be_operand<TRight>::value>>
+struct BinaryExpression
+{
+ TLeft lhs;
+ TRight rhs;
+ BinaryOp opcode;
+};
+
+template <typename TLeft, typename TRight>
+struct can_be_operand<BinaryExpression<TLeft, TRight>> : ::std::true_type
+{
+};
+
+/** Represents the expression: `\p lhs + \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator+(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Add };
+}
+
+/** Represents the expression: `\p lhs - \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator-(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Sub };
+}
+
+/** Represents the expression: `\p lhs * \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator*(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mul };
+}
+
+/** Represents the expression: `\p lhs / \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator/(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Div };
+}
+
+/** Represents the expression: `\p lhs % \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator%(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Mod };
+}
+
+/** Represents the expression: `\p lhs == \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator==(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Equal };
+}
+
+/** Represents the expression: `\p lhs < \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator<(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Less };
+}
+
+/** Represents the expression: `\p lhs <= \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator<=(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LessEqual };
+}
+
+/** Represents the expression: `\p lhs > \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator>(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::Greater };
+}
+
+/** Represents the expression: `\p lhs >= \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator>=(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::GreaterEqual };
+}
+
+/** Represents the expression: `\p lhs ^ \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> operator^(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::BitwiseXOR };
+}
+
+/** Represents the expression: `\p lhs && \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> logical_and(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd };
+}
+
+/** Represents the left-nested expression: `(\p lhs && \p rhs) && \p ops...`.
+ *
+ * The first two operands are combined into a BinaryExpression, which is then
+ * combined with the remaining operands one at a time, from left to right.
+ *
+ * @tparam TLeft  The type of the first operand.
+ * @tparam TRight The type of the second operand.
+ * @tparam TOps   The types of the remaining operands.
+ * @param[in] lhs The first operand.
+ * @param[in] rhs The second operand.
+ * @param[in] ops The remaining operands.
+ * @return The resulting (nested) AST node.
+ */
+template <typename TLeft, typename TRight, typename... TOps>
+inline auto logical_and(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
+{
+    // The return type is deduced: the nesting depth of the resulting
+    // BinaryExpression grows with the number of operands, so it cannot be
+    // spelled as a single BinaryExpression<..., TOps...>.
+    return logical_and(
+        BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalAnd },
+        std::forward<TOps>(ops)...);
+}
+
+/** Represents the expression: `\p lhs || \p rhs`.
+ *
+ * @tparam TLeft The type of the LHS of the expression.
+ * @tparam TRight The type of the RHS of the expression.
+ * @param[in] lhs The LHS of the expression.
+ * @param[in] rhs The RHS of the expression.
+ * @return The resulting AST node.
+ */
+template <typename TLeft, typename TRight>
+inline BinaryExpression<TLeft, TRight> logical_or(TLeft &&lhs, TRight &&rhs)
+{
+ return BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr };
+}
+
+/** Represents the left-nested expression: `(\p lhs || \p rhs) || \p ops...`.
+ *
+ * The first two operands are combined into a BinaryExpression, which is then
+ * combined with the remaining operands one at a time, from left to right.
+ *
+ * @tparam TLeft  The type of the first operand.
+ * @tparam TRight The type of the second operand.
+ * @tparam TOps   The types of the remaining operands.
+ * @param[in] lhs The first operand.
+ * @param[in] rhs The second operand.
+ * @param[in] ops The remaining operands.
+ * @return The resulting (nested) AST node.
+ */
+template <typename TLeft, typename TRight, typename... TOps>
+inline auto logical_or(TLeft &&lhs, TRight &&rhs, TOps &&...ops)
+{
+    // The return type is deduced: the nesting depth of the resulting
+    // BinaryExpression grows with the number of operands, so it cannot be
+    // spelled as a single BinaryExpression<..., TOps...>.
+    return logical_or(
+        BinaryExpression<TLeft, TRight>{ std::forward<TLeft>(lhs), std::forward<TRight>(rhs), BinaryOp::LogicalOr },
+        std::forward<TOps>(ops)...);
+}
+
+// ==================================================
+// Unary elementwise functions
+// ==================================================
+
+/** AST node for unary elementwise functions.
+ *
+ * Stores the argument and the function code of a call such as `exp(src)`.
+ *
+ * \p TSrc must be an operand. Naming this template with any other type is a
+ * substitution failure, which keeps function templates returning this node out
+ * of overload resolution for unrelated types.
+ *
+ * @tparam TSrc The type of the argument to the function.
+ */
+template <typename TSrc, typename = ::std::enable_if_t<can_be_operand<TSrc>::value>>
+struct UnaryElementwiseFunction
+{
+    TSrc          src;    /**< Argument of the function. */
+    UnaryFunction opcode; /**< Function to apply. */
+};
+
+template <typename TLeft>
+struct can_be_operand<UnaryElementwiseFunction<TLeft>> : ::std::true_type
+{
+};
+
+/** Represents the expression: `exp(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> exp(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Exp };
+}
+
+/** Represents the expression: `tanh(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> tanh(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Tanh };
+}
+
+/** Represents the expression: `sqrt(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> sqrt(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Sqrt };
+}
+
+/** Represents the expression: `erf(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> erf(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Erf };
+}
+
+/** Represents the expression: `fabs(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> fabs(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Fabs };
+}
+
+/** Represents the expression: `log(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> log(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Log };
+}
+
+/** Represents the expression: `round(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> round(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::Round };
+}
+
+/** Represents the expression: `sizeof(\p src)`.
+ *
+ * @tparam TSrc The type of the argument.
+ * @param[in] src The argument.
+ * @return The resulting AST node.
+ */
+template <typename TSrc>
+UnaryElementwiseFunction<TSrc> sizeOf(TSrc &&src)
+{
+ return UnaryElementwiseFunction<TSrc>{ std::forward<TSrc>(src), UnaryFunction::SizeOf };
+}
+
+// ==================================================
+// Binary elementwise functions
+// ==================================================
+
+/** AST node for binary elementwise functions.
+ *
+ * Stores both arguments and the function code of a call such as
+ * `max(first, second)`.
+ *
+ * Both \p TFirst and \p TSecond must be operands. Naming this template with
+ * any other pair of types is a substitution failure, which keeps function
+ * templates returning this node out of overload resolution for unrelated types.
+ *
+ * @tparam TFirst  The type of the first argument of the function.
+ * @tparam TSecond The type of the second argument of the function.
+ */
+template <typename TFirst, typename TSecond, typename = ::std::enable_if_t<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value>>
+struct BinaryElementwiseFunction
+{
+    TFirst         first;  /**< First argument of the function. */
+    TSecond        second; /**< Second argument of the function. */
+    BinaryFunction opcode; /**< Function to apply. */
+};
+
+template <typename TFirst, typename TSecond>
+struct can_be_operand<BinaryElementwiseFunction<TFirst, TSecond>> : ::std::true_type
+{
+};
+
+/** Represents the function call: `max(\p first, \p second)`.
+ *
+ * @tparam TFirst The type of the first argument.
+ * @tparam TSecond The type of the second argument.
+ * @param[in] first The first argument.
+ * @param[in] second The second argument.
+ * @return The resulting AST node.
+ */
+template <typename TFirst, typename TSecond>
+BinaryElementwiseFunction<TFirst, TSecond> max(TFirst &&first, TSecond &&second)
+{
+ return BinaryElementwiseFunction<TFirst, TSecond>{ std::forward<TFirst>(first), std::forward<TSecond>(second), BinaryFunction::Max };
+}
+
+/** Represents the function call: `min(\p first, \p second)`.
+ *
+ * @tparam TFirst The type of the first argument.
+ * @tparam TSecond The type of the second argument.
+ * @param[in] first The first argument.
+ * @param[in] second The second argument.
+ * @return The resulting AST node.
+ */
+template <typename TFirst, typename TSecond>
+BinaryElementwiseFunction<TFirst, TSecond> min(TFirst &&first, TSecond &&second)
+{
+ return BinaryElementwiseFunction<TFirst, TSecond>{ std::forward<TFirst>(first), std::forward<TSecond>(second), BinaryFunction::Min };
+}
+
+// ==================================================
+// Ternary elementwise functions
+// ==================================================
+
+/** AST node for ternary elementwise functions.
+ *
+ * Stores the three arguments and the function code of a call such as
+ * `select(first, second, third)`.
+ *
+ * \p TFirst, \p TSecond, and \p TThird must all be operands. Naming this
+ * template with any other combination of types is a substitution failure,
+ * which keeps function templates returning this node out of overload
+ * resolution for unrelated types.
+ *
+ * @tparam TFirst  The type of the first argument to the function.
+ * @tparam TSecond The type of the second argument to the function.
+ * @tparam TThird  The type of the third argument to the function.
+ */
+template <typename TFirst, typename TSecond, typename TThird, typename = ::std::enable_if_t<can_be_operand<TFirst>::value && can_be_operand<TSecond>::value && can_be_operand<TThird>::value>>
+struct TernaryElementwiseFunction
+{
+    TFirst          first;  /**< First argument of the function. */
+    TSecond         second; /**< Second argument of the function. */
+    TThird          third;  /**< Third argument of the function. */
+    TernaryFunction opcode; /**< Function to apply. */
+};
+
+/** Trait specialization: a TernaryElementwiseFunction node is reported as a valid operand
+ * (the specialization derives from std::true_type).
+ */
+template <typename TFirst, typename TSecond, typename TThird>
+struct can_be_operand<TernaryElementwiseFunction<TFirst, TSecond, TThird>> : ::std::true_type
+{
+};
+
+/** Builds the AST node describing the call `select(\p first, \p second, \p third)`.
+ *
+ * Nothing is evaluated here: the three arguments are perfectly forwarded into the node,
+ * which is tagged with TernaryFunction::Select.
+ *
+ * @tparam TFirst  Deduced type of the first argument.
+ * @tparam TSecond Deduced type of the second argument.
+ * @tparam TThird  Deduced type of the third argument.
+ * @param[in] first  First argument of the call.
+ * @param[in] second Second argument of the call.
+ * @param[in] third  Third argument of the call.
+ * @return The AST node holding the three arguments and the opcode.
+ */
+template <typename TFirst, typename TSecond, typename TThird>
+TernaryElementwiseFunction<TFirst, TSecond, TThird> select(TFirst &&first, TSecond &&second, TThird &&third)
+{
+    using Node = TernaryElementwiseFunction<TFirst, TSecond, TThird>;
+    return Node{
+        std::forward<TFirst>(first),
+        std::forward<TSecond>(second),
+        std::forward<TThird>(third),
+        TernaryFunction::Select
+    };
+}
+
+/** Helper class used to extend a KernelWriter with additional functionality
+ * in order to make writing easier.
+ *
+ * This extension automatically handles creation of temporary variables, and
+ * allows nested function calls and operations.
+ *
+ * @tparam TWriter The type of KernelWriter to be overloaded. This must inherit from KernelWriter.
+ */
+template <class TWriter, typename = std::enable_if<std::is_base_of<KernelWriter, TWriter>::value>>
+class KernelWriterHelper : public TWriter
+{
+public:
+ using TWriter::TWriter;
+
+ // ==================================================
+ // If-statements
+ // ==================================================
+
+ // Un-hide original implementation, in case the original implementation is required.
+ using TWriter::op_if;
+
+    /** Emits `if(\p cond) { \p body }` for a condition whose two sides are already tiles.
+     *
+     * No temporary is required: the left tile, the operator and the right tile are
+     * handed directly to TWriter::op_if.
+     *
+     * @param[in] cond The condition; both of its operands are tiles.
+     * @param[in] body Callback passed on as the body of the if-statement.
+     *
+     * @return A reference to this helper.
+     */
+    KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TileOperand &> &cond, const std::function<void()> &body)
+    {
+        auto &lhs_tile = cond.lhs;
+        auto &rhs_tile = cond.rhs;
+        TWriter::op_if(lhs_tile, cond.opcode, rhs_tile, body);
+        return *this;
+    }
+
+    /** Emits `if(\p cond) { \p body }` when only the left side of the condition is a tile.
+     *
+     * The right-hand sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of the left operand); the condition is then emitted
+     * between the left tile and that temporary through TWriter::op_if.
+     *
+     * @tparam TRight The type of the right-hand sub-expression.
+     * @param[in] cond The condition to evaluate.
+     * @param[in] body Callback passed on as the body of the if-statement.
+     *
+     * @return A reference to this helper.
+     */
+    template <typename TRight>
+    KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TileOperand &, TRight> &cond, const std::function<void()> &body)
+    {
+        TileOperand &rhs_tile = this->declare_temp_tile(cond.lhs.tile_info());
+        this->op_assign(rhs_tile, cond.rhs);
+        TWriter::op_if(cond.lhs, cond.opcode, rhs_tile, body);
+        return *this;
+    }
+
+    /** Emits `if(\p cond) { \p body }` when only the right side of the condition is a tile.
+     *
+     * The left-hand sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of the right operand); the condition is then emitted
+     * between that temporary and the right tile through TWriter::op_if.
+     *
+     * @tparam TLeft The type of the left-hand sub-expression.
+     * @param[in] cond The condition to evaluate.
+     * @param[in] body Callback passed on as the body of the if-statement.
+     *
+     * @return A reference to this helper.
+     */
+    template <typename TLeft>
+    KernelWriterHelper<TWriter> &op_if(const BinaryExpression<TLeft, TileOperand &> &cond, const std::function<void()> &body)
+    {
+        TileOperand &lhs_tile = this->declare_temp_tile(cond.rhs.tile_info());
+        this->op_assign(lhs_tile, cond.lhs);
+        TWriter::op_if(lhs_tile, cond.opcode, cond.rhs, body);
+        return *this;
+    }
+
+ // Un-hide original implementation, in case the original implementation is required.
+ using TWriter::op_else_if;
+
+    /** Emits `else if(\p cond) { \p body }` for a condition whose two sides are already tiles.
+     *
+     * No temporary is required: the left tile, the operator and the right tile are
+     * handed directly to TWriter::op_else_if.
+     *
+     * @param[in] cond The condition; both of its operands are tiles.
+     * @param[in] body Callback passed on as the body of the else-if-statement.
+     *
+     * @return A reference to this helper.
+     */
+    KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TileOperand &> &cond, const std::function<void()> &body)
+    {
+        auto &lhs_tile = cond.lhs;
+        auto &rhs_tile = cond.rhs;
+        TWriter::op_else_if(lhs_tile, cond.opcode, rhs_tile, body);
+        return *this;
+    }
+
+    /** Emits `else if(\p cond) { \p body }` when only the left side of the condition is a tile.
+     *
+     * The right-hand sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of the left operand); the condition is then emitted
+     * between the left tile and that temporary through TWriter::op_else_if.
+     *
+     * @tparam TRight The type of the right-hand sub-expression.
+     * @param[in] cond The condition to evaluate.
+     * @param[in] body Callback passed on as the body of the else-if-statement.
+     *
+     * @return A reference to this helper.
+     */
+    template <typename TRight>
+    KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TileOperand &, TRight> &cond, const std::function<void()> &body)
+    {
+        TileOperand &rhs_tile = this->declare_temp_tile(cond.lhs.tile_info());
+        this->op_assign(rhs_tile, cond.rhs);
+        TWriter::op_else_if(cond.lhs, cond.opcode, rhs_tile, body);
+        return *this;
+    }
+
+    /** Emits `else if(\p cond) { \p body }` when only the right side of the condition is a tile.
+     *
+     * The left-hand sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of the right operand); the condition is then emitted
+     * between that temporary and the right tile through TWriter::op_else_if.
+     *
+     * @tparam TLeft The type of the left-hand sub-expression.
+     * @param[in] cond The condition to evaluate.
+     * @param[in] body Callback passed on as the body of the else-if-statement.
+     *
+     * @return A reference to this helper.
+     */
+    template <typename TLeft>
+    KernelWriterHelper<TWriter> &op_else_if(const BinaryExpression<TLeft, TileOperand &> &cond, const std::function<void()> &body)
+    {
+        TileOperand &lhs_tile = this->declare_temp_tile(cond.rhs.tile_info());
+        this->op_assign(lhs_tile, cond.lhs);
+        TWriter::op_else_if(lhs_tile, cond.opcode, cond.rhs, body);
+        return *this;
+    }
+
+ // ==================================================
+ // For-loops
+ // ==================================================
+
+ // Un-hide original implementation, in case the original implementation is required.
+ using TWriter::op_for_loop;
+
+    /** Emits the for-loop `for(;\p cond; \p updater) { \p body }`.
+     *
+     * The condition and the updater are split into their tiles and opcodes, which are
+     * passed, in that order, to TWriter::op_for_loop together with the body.
+     *
+     * @param[in] cond    The loop condition; both of its operands are tiles.
+     * @param[in] updater The update step; both of its operands are tiles.
+     * @param[in] body    Callback passed on as the body of the for-loop.
+     */
+    void op_for_loop(const BinaryExpression<TileOperand &, TileOperand &> &cond, const Assignment<TileOperand &, TileOperand &> &updater, const std::function<void()> &body)
+    {
+        auto &cond_lhs   = cond.lhs;
+        auto &cond_rhs   = cond.rhs;
+        auto &update_lhs = updater.lhs;
+        auto &update_rhs = updater.rhs;
+        TWriter::op_for_loop(cond_lhs, cond.opcode, cond_rhs, update_lhs, updater.opcode, update_rhs, body);
+    }
+
+ // ==================================================
+ // Unary expressions
+ // ==================================================
+
+ // Un-hide original implementation, in case the original implementation is required.
+ using TWriter::op_assign;
+
+    /** Emits the assignment `\p dst = \p exp` for a unary expression applied to a tile.
+     *
+     * The opcode and the source tile of the expression are passed to
+     * TWriter::op_unary_expression.
+     *
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The unary expression to evaluate; its operand is a tile.
+     */
+    void op_assign(const TileOperand &dst, const UnaryExpression<TileOperand &> &exp)
+    {
+        auto &operand = exp.src;
+        TWriter::op_unary_expression(dst, exp.opcode, operand);
+    }
+
+ // ==================================================
+ // Binary expressions
+ // ==================================================
+
+    /** Emits the assignment `\p dst = \p exp` for a binary expression between two tiles.
+     *
+     * No temporary is required: both tiles and the opcode are passed to
+     * TWriter::op_binary_expression.
+     *
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The binary expression to evaluate; both of its operands are tiles.
+     */
+    void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TileOperand &> &exp)
+    {
+        auto &lhs_tile = exp.lhs;
+        auto &rhs_tile = exp.rhs;
+        TWriter::op_binary_expression(dst, lhs_tile, exp.opcode, rhs_tile);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` when only the left operand is a tile.
+     *
+     * The right-hand sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of \p dst); the binary expression is then emitted
+     * between the left tile and that temporary through TWriter::op_binary_expression.
+     *
+     * This function only generates kernel code; it writes nothing to standard output.
+     *
+     * @tparam TRight The type of the right-hand sub-expression.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
+     */
+    template <typename TRight>
+    void op_assign(const TileOperand &dst, const BinaryExpression<TileOperand &, TRight> &exp)
+    {
+        auto &tmp1 = declare_temp_tile(dst.tile_info());
+        op_assign(tmp1, exp.rhs);
+        TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, tmp1);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` when only the right operand is a tile.
+     *
+     * The left-hand sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of \p dst); the binary expression is then emitted
+     * between that temporary and the right tile through TWriter::op_binary_expression.
+     *
+     * This function only generates kernel code; it writes nothing to standard output.
+     *
+     * @tparam TLeft The type of the left-hand sub-expression.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned.
+     */
+    template <typename TLeft>
+    void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TileOperand &> &exp)
+    {
+        auto &tmp1 = declare_temp_tile(dst.tile_info());
+        op_assign(tmp1, exp.lhs);
+        TWriter::op_binary_expression(dst, tmp1, exp.opcode, exp.rhs);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` when neither operand is a plain tile.
+     *
+     * Two temporary tiles are declared (both with the tile info of \p dst), the left and
+     * right sub-expressions are assigned to them in that order, and the binary expression
+     * is then emitted between the two temporaries through TWriter::op_binary_expression.
+     *
+     * @tparam TLeft  The type of the left-hand sub-expression.
+     * @tparam TRight The type of the right-hand sub-expression.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The binary expression to evaluate.
+     */
+    template <typename TLeft, typename TRight>
+    void op_assign(const TileOperand &dst, const BinaryExpression<TLeft, TRight> &exp)
+    {
+        TileOperand &lhs_tile = this->declare_temp_tile(dst.tile_info());
+        TileOperand &rhs_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(lhs_tile, exp.lhs);
+        this->op_assign(rhs_tile, exp.rhs);
+        TWriter::op_binary_expression(dst, lhs_tile, exp.opcode, rhs_tile);
+    }
+
+ // ==================================================
+ // Unary elementwise functions
+ // ==================================================
+
+    /** Emits the assignment `\p dst = \p exp` for a unary function applied to a tile.
+     *
+     * The opcode and the source tile of the function call are passed to
+     * TWriter::op_unary_elementwise_function.
+     *
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The unary function call to evaluate; its argument is a tile.
+     */
+    void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TileOperand &> &exp)
+    {
+        auto &argument = exp.src;
+        TWriter::op_unary_elementwise_function(dst, exp.opcode, argument);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a unary function whose argument is not a plain tile.
+     *
+     * The argument sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of \p dst); the function is then applied to that
+     * temporary through TWriter::op_unary_elementwise_function.
+     *
+     * The argument is read from the `src` member, the same member the tile overload of
+     * this function uses for UnaryElementwiseFunction.
+     *
+     * @tparam TArg The type of the argument sub-expression.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned.
+     */
+    template <typename TArg>
+    void op_assign(const TileOperand &dst, const UnaryElementwiseFunction<TArg> &exp)
+    {
+        auto &tmp1 = declare_temp_tile(dst.tile_info());
+        op_assign(tmp1, exp.src);
+        TWriter::op_unary_elementwise_function(dst, exp.opcode, tmp1);
+    }
+
+ // ==================================================
+ // Binary elementwise functions
+ // ==================================================
+
+    /** Emits the assignment `\p dst = \p exp` for a binary function applied to two tiles.
+     *
+     * No temporary is required: the opcode and both tiles are passed to
+     * TWriter::op_binary_elementwise_function.
+     *
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The binary function call to evaluate; both arguments are tiles.
+     */
+    void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TileOperand &> &exp)
+    {
+        auto &first_tile  = exp.first;
+        auto &second_tile = exp.second;
+        TWriter::op_binary_elementwise_function(dst, exp.opcode, first_tile, second_tile);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a binary function whose first argument is a tile.
+     *
+     * The second argument is first assigned to a newly declared temporary tile (declared
+     * with the tile info of \p dst); the function is then applied to the first tile and
+     * that temporary through TWriter::op_binary_elementwise_function.
+     *
+     * @tparam TRight The type of the second argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The binary function call to evaluate.
+     */
+    template <typename TRight>
+    void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TileOperand &, TRight> &exp)
+    {
+        TileOperand &second_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(second_tile, exp.second);
+        TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, second_tile);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a binary function whose second argument is a tile.
+     *
+     * The first argument is first assigned to a newly declared temporary tile (declared
+     * with the tile info of \p dst); the function is then applied to that temporary and
+     * the second tile through TWriter::op_binary_elementwise_function.
+     *
+     * @tparam TLeft The type of the first argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The binary function call to evaluate.
+     */
+    template <typename TLeft>
+    void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TileOperand &> &exp)
+    {
+        TileOperand &first_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(first_tile, exp.first);
+        TWriter::op_binary_elementwise_function(dst, exp.opcode, first_tile, exp.second);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a binary function with no plain-tile argument.
+     *
+     * Two temporary tiles are declared (both with the tile info of \p dst), the first and
+     * second arguments are assigned to them in that order, and the function is then applied
+     * to the two temporaries through TWriter::op_binary_elementwise_function.
+     *
+     * @tparam TLeft  The type of the first argument.
+     * @tparam TRight The type of the second argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The binary function call to evaluate.
+     */
+    template <typename TLeft, typename TRight>
+    void op_assign(const TileOperand &dst, const BinaryElementwiseFunction<TLeft, TRight> &exp)
+    {
+        TileOperand &first_tile  = this->declare_temp_tile(dst.tile_info());
+        TileOperand &second_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(first_tile, exp.first);
+        this->op_assign(second_tile, exp.second);
+        TWriter::op_binary_elementwise_function(dst, exp.opcode, first_tile, second_tile);
+    }
+
+ // ==================================================
+ // Ternary elementwise functions
+ // ==================================================
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function applied to three tiles.
+     *
+     * No temporary is required: the opcode and the three tiles are passed to
+     * TWriter::op_ternary_elementwise_function.
+     *
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The ternary function call to evaluate; all arguments are tiles.
+     */
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TileOperand &, TileOperand &> &exp)
+    {
+        auto &first_tile  = exp.first;
+        auto &second_tile = exp.second;
+        auto &third_tile  = exp.third;
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, first_tile, second_tile, third_tile);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function whose second and third arguments are tiles.
+     *
+     * The first argument is first assigned to a newly declared temporary tile (declared
+     * with the tile info of \p dst); that temporary then takes the first argument's place
+     * in the call to TWriter::op_ternary_elementwise_function.
+     *
+     * @tparam TFirst The type of the first argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The ternary function call to evaluate.
+     */
+    template <typename TFirst>
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TileOperand &> &exp)
+    {
+        TileOperand &first_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(first_tile, exp.first);
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, first_tile, exp.second, exp.third);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function whose first and third arguments are tiles.
+     *
+     * The second argument is first assigned to a newly declared temporary tile (declared
+     * with the tile info of \p dst); that temporary then takes the second argument's place
+     * in the call to TWriter::op_ternary_elementwise_function.
+     *
+     * @tparam TSecond The type of the second argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The ternary function call to evaluate.
+     */
+    template <typename TSecond>
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TileOperand &> &exp)
+    {
+        TileOperand &second_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(second_tile, exp.second);
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, second_tile, exp.third);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function whose first and second arguments are tiles.
+     *
+     * The third argument is first assigned to a newly declared temporary tile (declared
+     * with the tile info of \p dst); that temporary then takes the third argument's place
+     * in the call to TWriter::op_ternary_elementwise_function.
+     *
+     * @tparam TThird The type of the third argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The ternary function call to evaluate.
+     */
+    template <typename TThird>
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TileOperand &, TThird> &exp)
+    {
+        TileOperand &third_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(third_tile, exp.third);
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, third_tile);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function whose third argument is a tile.
+     *
+     * Two temporary tiles are declared (both with the tile info of \p dst), the first and
+     * second arguments are assigned to them in that order, and the temporaries then take
+     * those arguments' places in the call to TWriter::op_ternary_elementwise_function.
+     *
+     * @tparam TFirst  The type of the first argument.
+     * @tparam TSecond The type of the second argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The ternary function call to evaluate.
+     */
+    template <typename TFirst, typename TSecond>
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TileOperand &> &exp)
+    {
+        TileOperand &first_tile  = this->declare_temp_tile(dst.tile_info());
+        TileOperand &second_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(first_tile, exp.first);
+        this->op_assign(second_tile, exp.second);
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, first_tile, second_tile, exp.third);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function whose second argument is a tile.
+     *
+     * Two temporary tiles are declared (both with the tile info of \p dst), the first and
+     * third arguments are assigned to them in that order, and the temporaries then take
+     * those arguments' places in the call to TWriter::op_ternary_elementwise_function.
+     *
+     * @tparam TFirst The type of the first argument.
+     * @tparam TThird The type of the third argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The ternary function call to evaluate.
+     */
+    template <typename TFirst, typename TThird>
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TileOperand &, TThird> &exp)
+    {
+        TileOperand &first_tile = this->declare_temp_tile(dst.tile_info());
+        TileOperand &third_tile = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(first_tile, exp.first);
+        this->op_assign(third_tile, exp.third);
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, first_tile, exp.second, third_tile);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function whose first argument is a tile.
+     *
+     * Two temporary tiles are declared (both with the tile info of \p dst), the second and
+     * third arguments are assigned to them in that order, and the temporaries then take
+     * those arguments' places in the call to TWriter::op_ternary_elementwise_function.
+     *
+     * @tparam TSecond The type of the second argument.
+     * @tparam TThird  The type of the third argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The ternary function call to evaluate.
+     */
+    template <typename TSecond, typename TThird>
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TileOperand &, TSecond, TThird> &exp)
+    {
+        TileOperand &second_tile = this->declare_temp_tile(dst.tile_info());
+        TileOperand &third_tile  = this->declare_temp_tile(dst.tile_info());
+        this->op_assign(second_tile, exp.second);
+        this->op_assign(third_tile, exp.third);
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, second_tile, third_tile);
+    }
+
+    /** Emits the assignment `\p dst = \p exp` for a ternary function with no plain-tile argument.
+     *
+     * Three temporary tiles are declared, the first, second and third arguments are
+     * assigned to them in that order, and the function is then applied to the three
+     * temporaries through TWriter::op_ternary_elementwise_function.
+     *
+     * Each temporary is declared from a single `dst.tile_info()`, exactly as in the other
+     * op_assign overloads of this class.
+     *
+     * @tparam TFirst  The type of the first argument.
+     * @tparam TSecond The type of the second argument.
+     * @tparam TThird  The type of the third argument.
+     * @param[in] dst The tile which is assigned to.
+     * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned.
+     */
+    template <typename TFirst, typename TSecond, typename TThird>
+    void op_assign(const TileOperand &dst, const TernaryElementwiseFunction<TFirst, TSecond, TThird> &exp)
+    {
+        auto &tmp1 = declare_temp_tile(dst.tile_info());
+        auto &tmp2 = declare_temp_tile(dst.tile_info());
+        auto &tmp3 = declare_temp_tile(dst.tile_info());
+        op_assign(tmp1, exp.first);
+        op_assign(tmp2, exp.second);
+        op_assign(tmp3, exp.third);
+        TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, tmp3);
+    }
+
+ // ==================================================
+ // Assignments
+ // ==================================================
+
+    /** Emits the compound assignment `\p lhs += \p rhs` or `\p lhs -= \p rhs` between two tiles.
+     *
+     * The assignment is expressed as a binary expression whose destination and left
+     * operand are both the left-hand tile: BinaryOp::Add for AssignmentOp::Increment and
+     * BinaryOp::Sub for AssignmentOp::Decrement. Any other opcode emits nothing.
+     *
+     * @param[in] exp The assignment to evaluate; both of its operands are tiles.
+     */
+    void op_assign(const Assignment<TileOperand &, TileOperand &> &exp)
+    {
+        switch(exp.opcode)
+        {
+            case AssignmentOp::Increment:
+                TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Add, exp.rhs);
+                break;
+            case AssignmentOp::Decrement:
+                TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Sub, exp.rhs);
+                break;
+            default:
+                break;
+        }
+    }
+
+    /** Emits the compound assignment `\p lhs += \p rhs` or `\p lhs -= \p rhs` when the
+     * right-hand side is not a plain tile.
+     *
+     * The right-hand sub-expression is first assigned to a newly declared temporary tile
+     * (declared with the tile info of the left-hand tile); the assignment is then re-issued
+     * as a tile-to-tile Assignment with the same opcode.
+     *
+     * @tparam TRight The type of the RHS of the assignment.
+     * @param[in] exp The assignment to evaluate.
+     */
+    template <typename TRight>
+    void op_assign(const Assignment<TileOperand &, TRight> &exp)
+    {
+        using TileAssignment = Assignment<TileOperand &, TileOperand &>;
+
+        TileOperand &rhs_tile = this->declare_temp_tile(exp.lhs.tile_info());
+        this->op_assign(rhs_tile, exp.rhs);
+        this->op_assign(TileAssignment{ exp.lhs, rhs_tile, exp.opcode });
+    }
+
+private:
+ unsigned int temp_var_counter = 0;
+
+    /** Hands out the current temporary-variable counter and advances it by one.
+     *
+     * @return The counter value before the increment.
+     */
+    int next_ctr()
+    {
+        const auto current = temp_var_counter;
+        ++temp_var_counter;
+        return current;
+    }
+
+    /** Produces the name of the next temporary variable.
+     *
+     * The name is the prefix `tmp_` followed by the value obtained from next_ctr(),
+     * which also advances the counter.
+     *
+     * @return A temporary variable name.
+     */
+    std::string next_tmp_var_name()
+    {
+        std::string name("tmp_");
+        name += std::to_string(next_ctr());
+        return name;
+    }
+
+    /** Single-argument case: the largest size of one TileInfo is that TileInfo.
+     *
+     * Serves as the end point of the recursion of the overload taking several arguments.
+     *
+     * @param[in] arg The TileInfo to return.
+     * @return A copy of \p arg.
+     */
+    TileInfo get_largest_size(const TileInfo &arg)
+    {
+        TileInfo result(arg);
+        return result;
+    }
+
+    /** Returns a TileInfo object where the size in each dimension (width, height) is the largest
+     * found among the arguments in the corresponding dimension.
+     *
+     * The first two arguments are merged, and the merged result is combined with the
+     * remaining arguments through recursion. The data type of the result is taken from \p first.
+     *
+     * The enable_if condition compares two function types: one whose parameter list is
+     * \p TOps and one whose parameter list is equally long but made only of TileInfo. It
+     * therefore holds exactly when every type in \p TOps is TileInfo, including when
+     * \p TOps is empty. Writing `std::is_same<TOps..., TileInfo>` instead is only well-formed
+     * for a pack of exactly one type, which leaves the recursive call below (made with one
+     * argument fewer) without a viable overload.
+     *
+     * NOTE(review): the sizes are passed to the TileInfo constructor as (data type, width, height);
+     * confirm that this matches the constructor's parameter order.
+     *
+     * @tparam TOps Must be of TileInfo type.
+     * @param[in] first  A TileInfo object.
+     * @param[in] second A TileInfo object.
+     * @param[in] ops    A number of TileInfo objects (possibly none).
+     * @return A TileInfo object which represents the largest shape in each dimension across the arguments.
+     */
+    template <typename... TOps, typename = ::std::enable_if_t<std::is_same<void(TOps...), void(typename std::conditional<true, TileInfo, TOps>::type...)>::value>>
+    TileInfo get_largest_size(const TileInfo &first, const TileInfo &second, const TOps &...ops)
+    {
+        TileInfo largest = {
+            first.data_type(),
+            std::max(first.width(), second.width()),
+            std::max(first.height(), second.height())
+        };
+        return get_largest_size(largest, ops...);
+    }
+
+    /** Helper function to define a suitable TileOperand with appropriate TileInfo
+     * such that broadcasting is taken into account, based on the arguments provided.
+     *
+     * The tile is declared through TWriter::declare_tile under the next temporary variable
+     * name, with the TileInfo returned by get_largest_size() for the given arguments.
+     *
+     * The enable_if condition requires at least one argument and compares two function
+     * types: one whose parameter list is \p TArgs and one whose parameter list is equally
+     * long but made only of TileInfo. It therefore holds exactly when every type in
+     * \p TArgs is TileInfo. Writing `std::is_same<TArgs..., TileInfo>` instead is only
+     * well-formed for a pack of exactly one type, so it rejects any call with several arguments.
+     *
+     * @tparam TArgs Must be of TileInfo type.
+     * @param[in] args One or more TileInfo which determine the shape of the TileOperand to declare.
+     * @return A newly created TileOperand.
+     */
+    template <typename... TArgs, typename = ::std::enable_if_t<(sizeof...(TArgs) > 0) && std::is_same<void(TArgs...), void(typename std::conditional<true, TileInfo, TArgs>::type...)>::value>>
+    TileOperand &declare_temp_tile(const TArgs &...args)
+    {
+        return TWriter::declare_tile(next_tmp_var_name().c_str(), get_largest_size(args...));
+    }
+};
+
+} // namespace ckw
+
+#endif // CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Functions.h b/compute_kernel_writer/prototype/include/ckw/types/Functions.h
index 68146cb1c8..2dd5ed0b3d 100644
--- a/compute_kernel_writer/prototype/include/ckw/types/Functions.h
+++ b/compute_kernel_writer/prototype/include/ckw/types/Functions.h
@@ -37,7 +37,6 @@ enum class UnaryFunction : int32_t
Sqrt = 0x0002,
Erf = 0x0003,
Fabs = 0x0004,
- IsGreaterEqual = 0x0005,
Log = 0x0006,
Round = 0x0007,
diff --git a/compute_kernel_writer/prototype/include/ckw/types/Operators.h b/compute_kernel_writer/prototype/include/ckw/types/Operators.h
index 172650d5ae..14a88c91b4 100644
--- a/compute_kernel_writer/prototype/include/ckw/types/Operators.h
+++ b/compute_kernel_writer/prototype/include/ckw/types/Operators.h
@@ -68,8 +68,8 @@ enum class BinaryOp : int32_t
enum class AssignmentOp : int32_t
{
// Unary
- Increment = 0x0000, // +=
- Decrement = 0x0001, // -=
+ Increment = 0x0000, // +=
+ Decrement = 0x0001, // -=
};
} // namespace ckw
diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp
index 1ac9ede5b5..9122e518b4 100644
--- a/compute_kernel_writer/prototype/src/KernelWriter.cpp
+++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp
@@ -270,13 +270,14 @@ void KernelWriter::op_else(const std::function<void()> &body)
_impl->compound_statement_end();
}
-void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
+void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function<void()> &body)
{
auto impl_var_name = var_name.create_impl_operand(_impl.get());
auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get());
+ auto impl_update_var_name = update_var_name.create_impl_operand(_impl.get());
auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get());
- _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name);
+ _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, impl_update_var_name, update_op, impl_update_value_name);
_impl->compound_statement_begin();
body();
_impl->compound_statement_end();
diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h
index 72fa419fc2..05c7306e3a 100644
--- a/compute_kernel_writer/prototype/src/Prototype.h
+++ b/compute_kernel_writer/prototype/src/Prototype.h
@@ -2498,7 +2498,7 @@ public:
virtual void op_else_header() = 0;
- virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0;
+ virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, const Operand &update_var, AssignmentOp update_op, const Operand &update_value) = 0;
virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0;
@@ -3654,9 +3654,6 @@ public:
case UnaryFunction::Fabs:
_data->code += "fabs(";
break;
- case UnaryFunction::IsGreaterEqual:
- _data->code += "isgreaterequal(";
- break;
case UnaryFunction::Log:
_data->code += "log(";
break;
@@ -3798,11 +3795,12 @@ public:
_data->code += "else\n";
}
- void op_for_loop_header(const Operand& var_name, BinaryOp cond_op, const Operand& cond_value_name, AssignmentOp update_op, const Operand& update_value_name) override
+ void op_for_loop_header(const Operand& var_name, BinaryOp cond_op, const Operand& cond_value_name, const Operand &update_var_name, AssignmentOp update_op, const Operand& update_value_name) override
{
OperandUnpacker operands(_data->tiles, _data->arguments);
const IVectorTile *var = operands.unpack(var_name);
const IVectorTile *cond_value = operands.unpack(cond_value_name);
+ const IVectorTile *update_var = operands.unpack(update_var_name);
const IVectorTile *update_value = operands.unpack(update_value_name);
const int32_t dst_w = var->format().w;
@@ -3818,7 +3816,7 @@ public:
_data->code += " ";
_data->code += to_string(cond_op);
_data->code += " " + cond_value->scalar(0, 0).str + "; ";
- _data->code += var->scalar(0, 0).str;
+ _data->code += update_var->scalar(0, 0).str;
_data->code += " ";
_data->code += to_string(update_op);
_data->code += " " + update_value->scalar(0, 0).str + ")";