From fab6c210b37f1fa6b3e37a2583b18f8e4b5a4f12 Mon Sep 17 00:00:00 2001 From: Nikolaj Jensen Date: Tue, 27 Jun 2023 14:13:24 +0100 Subject: Design wrapper around CKW for easier writing Signed-off-by: Nikolaj Jensen Change-Id: I114cdedcaf05c6abde046741837eeb73b813aa9d Signed-off-by: Nikolaj Jensen Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/532180 Tested-by: bsgcomp Reviewed-by: Viet-Hoa Do Comments-Addressed: bsgcomp Signed-off-by: Nikolaj Jensen Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9921 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- compute_kernel_writer/prototype/CMakeLists.txt | 3 + .../prototype/examples/writer_helper.cpp | 118 ++ .../prototype/include/ckw/KernelWriter.h | 5 +- .../prototype/include/ckw/KernelWriterHelper.h | 1268 ++++++++++++++++++++ .../prototype/include/ckw/types/Functions.h | 1 - .../prototype/include/ckw/types/Operators.h | 4 +- .../prototype/src/KernelWriter.cpp | 5 +- compute_kernel_writer/prototype/src/Prototype.h | 10 +- 8 files changed, 1401 insertions(+), 13 deletions(-) create mode 100644 compute_kernel_writer/prototype/examples/writer_helper.cpp create mode 100644 compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h diff --git a/compute_kernel_writer/prototype/CMakeLists.txt b/compute_kernel_writer/prototype/CMakeLists.txt index 3d6a192378..13d1ae8fc4 100644 --- a/compute_kernel_writer/prototype/CMakeLists.txt +++ b/compute_kernel_writer/prototype/CMakeLists.txt @@ -73,3 +73,6 @@ target_link_libraries(ckw_prototype_examples_common PUBLIC ckw_prototype) add_executable(ckw_prototype_examples_add_exp_store examples/add_exp_store.cpp) target_link_libraries(ckw_prototype_examples_add_exp_store PUBLIC ckw_prototype_examples_common) + +add_executable(writer_helper examples/writer_helper.cpp) +target_link_libraries(writer_helper PUBLIC ckw_prototype) diff --git a/compute_kernel_writer/prototype/examples/writer_helper.cpp 
b/compute_kernel_writer/prototype/examples/writer_helper.cpp new file mode 100644 index 0000000000..ccef92dcdf --- /dev/null +++ b/compute_kernel_writer/prototype/examples/writer_helper.cpp @@ -0,0 +1,118 @@ +/* +* Copyright (c) 2023 Arm Limited. +* +* SPDX-License-Identifier: MIT +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to +* deal in the Software without restriction, including without limitation the +* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +* sell copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in all +* copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. 
+*/ + +#include "ckw/KernelWriter.h" +#include "../include/ckw/KernelWriterHelper.h" +#include "ckw/TensorTileSampler.h" + +#include + +using namespace ckw; + +TensorTileSampler create_simple_sampler(KernelWriter& writer) +{ + TensorTileSampler sampler; + + constexpr int32_t m0 = 1; + constexpr int32_t n0 = 1; + + auto &gid_0 = writer.declare_tile("gid_0", DataType::Int32); + auto &gid_1 = writer.declare_tile("gid_1", DataType::Int32); + auto &gid_2 = writer.declare_tile("gid_2", DataType::Int32); + + auto &const_0 = writer.declare_tile("0", 0); + + writer.op_get_global_id(gid_0, 0); + writer.op_get_global_id(gid_1, 1); + writer.op_get_global_id(gid_2, 2); + + sampler.x(gid_0); + sampler.y(gid_1); + sampler.z(gid_2); + sampler.b(const_0); + + sampler.width(n0); + sampler.height(m0); + + sampler.format(TensorSamplerFormat::C_WH_1); + sampler.address_mode_x(TensorSamplerAddressModeX::None); + sampler.address_mode_y(TensorSamplerAddressModeY::ClampToBorder); + sampler.address_mode_z(TensorSamplerAddressModeZ::Skip); + + return sampler; +} + +int main() +{ + Kernel kernel("test", GpuTargetLanguage::OpenCL); + KernelWriterHelper writer(kernel); + + const TensorInfo src_info(DataType::Fp32, TensorShape({ 1, 1, 1, 1, 1 }), TensorDataLayout::Nhwc, 0); + const TensorInfo dst_info(DataType::Fp32, TensorShape({ 1, 1, 1, 1, 1 }), TensorDataLayout::Nhwc, 1); + + auto &src_tensor = writer.declare_tensor_argument("src", src_info); + auto &dst_tensor = writer.declare_tensor_argument("dst", dst_info); + + const auto sampler = create_simple_sampler(writer); + + auto &src = writer.declare_tile("src_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width())); + auto &other = writer.declare_tile("other_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width())); + auto &dst = writer.declare_tile("dst_tile", TileInfo(src_tensor.data_type(), sampler.height(), sampler.width())); + + writer.op_load(src, src_tensor, sampler); + writer.op_load(other, 
src_tensor, sampler); + writer.op_load(dst, dst_tensor, sampler); + + auto test = dst ^ src ^ other; + auto other_test = logical_and(dst, src, other); + writer.op_assign(dst, logical_and(dst, src, other)); + writer.op_assign(dst, test); + writer.op_assign(dst, other_test); + writer.op_assign(dst, operator^(operator^(dst, src), other)); + + writer.op_if(exp(src) == dst, [&]{ + writer.op_binary_expression(dst, src, BinaryOp::Add, src); + }).op_else_if(exp(src) > dst, [&]{ + writer.op_binary_expression(dst, src, BinaryOp::Add, src); + }).op_else([&] { + writer.op_assign(dst, src); + }); + + writer.op_assign(dst, src + src * src); + writer.op_assign(dst, src * max(src, dst) + src); + writer.op_assign(dst, src * select(src, dst, src) + src); + + writer.op_assign(dst, src ^ dst); + writer.op_assign(dst, ~src); + + writer.op_for_loop(dst < src, dst += src, [&]{ + writer.op_assign(dst, src + dst); + }); + + writer.op_assign(dst += src); + writer.op_assign(dst += exp(src)); + + std::cout << "======== KERNEL ========" << std::endl; + std::cout << writer.generate_code() << std::endl; +} \ No newline at end of file diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h index 146fdac53e..c116e62650 100644 --- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h +++ b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h @@ -230,16 +230,17 @@ public: */ void op_else(const std::function &body); - /** Write for-loops: `for(; ; ) { body }`. + /** Write for-loops: `for(; ; ) { body }`. * * @param[in] var_name The name of the variable used in condition. * @param[in] cond_op The relational binary operator used in condition. * @param[in] cond_value_name The value which the variable is compared against. + * @param[in] update_var_name The name of the variable which is updated. * @param[in] update_op The assignment operator used for updating the update value. 
* @param[in, out] update_value The value which is updated at every iteration. * @param[in] body The body of the for-loop. */ - void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function &body); + void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function &body); /** Write the return statement: `return;` */ diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h b/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h new file mode 100644 index 0000000000..a8be859680 --- /dev/null +++ b/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h @@ -0,0 +1,1268 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef CKW_INCLUDE_CKW_KERNELWRITERHELPER_H +#define CKW_INCLUDE_CKW_KERNELWRITERHELPER_H + +#include "ckw/KernelWriter.h" +#include "ckw/TensorOperand.h" +#include "ckw/TileOperand.h" + +#include +#include + +#include + +/* + * By including this header file you will be able to supplement the default + * Compute Kernel Writer API with additional syntax to help ease the use of CKW. + * + * To use the KernelWriterHelper you need to wrap your instance of KernelWriter + * (or any class deriving from KernelWriter): + * KernelWriterHelper writer; + * The resulting writer object comprises the original KernelWriter + * functionality (drop-in replacement), but extends the syntax as follows. + * + * Common functions/operators have natural syntax: + * 1. Unary expressions: + * writer.op_assign(dst, !src); // Logical NOT + * writer.op_assign(dst, ~src); // Bitwise NOT + * + * 2. Binary expressions: + * writer.op_assign(dst, lhs + rhs); // Addition + * writer.op_assign(dst, lhs - rhs); // Subtraction + * writer.op_assign(dst, lhs * rhs); // Multiplication + * writer.op_assign(dst, lhs / rhs); // Division + * writer.op_assign(dst, lhs % rhs); // Modulo + * writer.op_assign(dst, lhs == rhs); // Equality + * writer.op_assign(dst, lhs < rhs); // Less-than + * writer.op_assign(dst, lhs <= rhs); // Less-than-or-equal + * writer.op_assign(dst, lhs > rhs); // Greater-than + * writer.op_assign(dst, lhs >= rhs); // Greater-than-or-equal + * writer.op_assign(dst, lhs ^ rhs); // Bitwise XOR + * writer.op_assign(dst, logical_and(lhs, rhs)); // Logical AND + * writer.op_assign(dst, logical_or(lhs, rhs)); // Logical OR + * + * 3. 
Unary elementwise functions: + * writer.op_assign(dst, exp(src)); // Exponent + * writer.op_assign(dst, tanh(src)); // Hyperbolic tangent + * writer.op_assign(dst, sqrt(src)); // Square root + * writer.op_assign(dst, erf(src)); // Error function + * writer.op_assign(dst, fabs(src)); // Absolute value of a floating-point number + * writer.op_assign(dst, log(src)); // Natural logarithm + * writer.op_assign(dst, round(src)); // Round + * writer.op_assign(dst, sizeOf(src)); // sizeof + * + * 4. Binary elementwise functions: + * writer.op_assign(dst, max(first, second)); // Max + * writer.op_assign(dst, min(first, second)); // Min + * + * 5. Ternary elementwise functions: + * writer.op_assign(dst, select(first, second, third)); // Select + * + * NOTE: All the above examples support nesting, so you could write + * something like: writer.op_assign(dst, src * (log(arg) + sqrt(fabs(arg)))); + * + * + * 6. If-statements. The preceding syntax also allows easier writing of if-statements: + * writer.op_if(<cond>, <body>); + * + * For example: + * writer.op_if(exp(first_arg) == dst, [&]{ + * //... + * }).op_else_if(exp(first_arg) > dst, [&]{ + * //... + * }).op_else([&] { + * //... + * }); + * + * 7. For-loops. A similar syntax exists for for-loops: + * writer.op_for_loop(<cond>, <updater>, <body>); + * + * For example: + * writer.op_for_loop(index < limit, index += step, [&]{ + * //... + * }); + * + * NOTE: There are limitations on the for-loop <cond> and <updater> parameters. + * Neither <cond> (a binary expression) nor <updater> (an increment/decrement assignment) + * is allowed to use nesting. For example, `(index + other) < limit` and + * `index < round(limit)` are invalid parameters. This is because the + * semantics of for-loops rely on the condition being evaluated at every iteration, + * but as temporary variables might be defined for nested expressions the semantics + * cannot be guaranteed. 
+ */ + +namespace ckw +{ + +// ================================================== +// Type traits +// ================================================== + +/** Specifies if the type can be used as an operand for functions (e.g. max), operations (e.g. *), or assignments. */ +template +struct can_be_operand : ::std::false_type +{ +}; + +/** Specifies if the type can be assigned/written to. */ +template +struct can_be_assigned : ::std::false_type +{ +}; + +template <> +struct can_be_operand : ::std::true_type +{ +}; + +template <> +struct can_be_assigned : ::std::true_type +{ +}; + +// ================================================== +// Assignment +// ================================================== + +/** AST node for assignments. + * + * Note that \p TRight must be an operand, and \p TLeft must be assignable. + * + * @tparam TLeft The type of the destination of the assignment. + * @tparam TRight The type of the source assigned to the destination. + */ +template ::value && can_be_assigned::value>> +struct Assignment +{ + TLeft lhs; + TRight rhs; + AssignmentOp opcode; +}; + +/** Represents the expression: `\p lhs += \p rhs`. + * + * @tparam TLeft The type of the LHS of the assignment. + * @tparam TRight The type of the RHS of the assignment. + * @param[in] lhs The LHS of the assignment. + * @param[in] rhs The RHS of the assignment. + * @return The resulting AST node. + */ +template +inline Assignment operator+=(TLeft &&lhs, TRight &&rhs) +{ + return Assignment{ std::forward(lhs), std::forward(rhs), AssignmentOp::Increment }; +} + +/** Represents the expression: `\p lhs -= \p rhs`. + * + * @tparam TLeft The type of the LHS of the assignment. + * @tparam TRight The type of the RHS of the assignment. + * @param[in] lhs The LHS of the assignment. + * @param[in] rhs The RHS of the assignment. + * @return The resulting AST node. 
+ */ +template +inline Assignment operator-=(TLeft &&lhs, TRight &&rhs) +{ + return Assignment{ std::forward(lhs), std::forward(rhs), AssignmentOp::Decrement }; +} + +// ================================================== +// Unary expression +// ================================================== + +/** AST node for unary expressions. + * + * Note that \p TSrc must be an operand. + * + * @tparam TSrc The type of the argument to the expression. + */ +template ::value>> +struct UnaryExpression +{ + TSrc src; + UnaryOp opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the expression: `!\p src`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +inline UnaryExpression operator!(TSrc &&src) +{ + return UnaryExpression{ std::forward(src), UnaryOp::LogicalNot }; +} + +/** Represents the expression: `~\p src`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +inline UnaryExpression operator~(TSrc &&src) +{ + return UnaryExpression{ std::forward(src), UnaryOp::BitwiseNot }; +} + +// ================================================== +// Binary expressions +// ================================================== + +/** AST node for binary expressions. + * + * Note that both \p TLeft and \p TRight must be operands. + * + * @tparam TLeft The type of the left argument of the expression. + * @tparam TRight The type of the right argument of the expression. + */ +template ::value && can_be_operand::value>> +struct BinaryExpression +{ + TLeft lhs; + TRight rhs; + BinaryOp opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the expression: `\p lhs + \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. 
+ * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator+(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Add }; +} + +/** Represents the expression: `\p lhs - \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator-(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Sub }; +} + +/** Represents the expression: `\p lhs * \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator*(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Mul }; +} + +/** Represents the expression: `\p lhs / \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator/(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Div }; +} + +/** Represents the expression: `\p lhs % \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. 
+ */ +template +inline BinaryExpression operator%(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Mod }; +} + +/** Represents the expression: `\p lhs == \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator==(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Equal }; +} + +/** Represents the expression: `\p lhs < \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator<(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Less }; +} + +/** Represents the expression: `\p lhs <= \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator<=(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LessEqual }; +} + +/** Represents the expression: `\p lhs > \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. 
+ */ +template +inline BinaryExpression operator>(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Greater }; +} + +/** Represents the expression: `\p lhs >= \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator>=(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::GreaterEqual }; +} + +/** Represents the expression: `\p lhs ^ \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator^(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::BitwiseXOR }; +} + +/** Represents the expression: `\p lhs && \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression logical_and(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalAnd }; +} + +/** Represents the expression: `\p lhs && \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. 
+ */ +template +inline BinaryExpression, TOps...> logical_and(TLeft &&lhs, TRight &&rhs, TOps &&...ops) +{ + return logical_and( + BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalAnd }, + std::forward(ops)...); +} + +/** Represents the expression: `\p lhs || \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression logical_or(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalOr }; +} + +/** Represents the expression: `\p lhs || \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression, TOps...> logical_or(TLeft &&lhs, TRight &&rhs, TOps &&...ops) +{ + return logical_or( + BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalOr }, + std::forward(ops)...); +} + +// ================================================== +// Unary elementwise functions +// ================================================== + +/** AST node for unary elementwise functions. + * + * Note that \p TSrc must be an operand. + * + * @tparam TSrc The type of the argument to the function. + */ +template ::value>> +struct UnaryElementwiseFunction +{ + TSrc src; + UnaryFunction opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the expression: `exp(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. 
+ */ +template +UnaryElementwiseFunction exp(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Exp }; +} + +/** Represents the expression: `tanh(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction tanh(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Tanh }; +} + +/** Represents the expression: `sqrt(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction sqrt(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Sqrt }; +} + +/** Represents the expression: `erf(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction erf(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Erf }; +} + +/** Represents the expression: `fabs(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction fabs(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Fabs }; +} + +/** Represents the expression: `log(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction log(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Log }; +} + +/** Represents the expression: `round(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. 
+ */ +template +UnaryElementwiseFunction round(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Round }; +} + +/** Represents the expression: `sizeof(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction sizeOf(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::SizeOf }; +} + +// ================================================== +// Binary elementwise functions +// ================================================== + +/** AST node for binary elementwise functions. + * + * Note that both \p TFirst and \p TSecond must be operands. + * + * @tparam TFirst The type of the left argument of the function. + * @tparam TSecond The type of the right argument of the function. + */ +template ::value && can_be_operand::value>> +struct BinaryElementwiseFunction +{ + TFirst first; + TSecond second; + BinaryFunction opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the function call: `max(\p first, \p second)`. + * + * @tparam TFirst The type of the first argument. + * @tparam TSecond The type of the second argument. + * @param[in] first The first argument. + * @param[in] second The second argument. + * @return The resulting AST node. + */ +template +BinaryElementwiseFunction max(TFirst &&first, TSecond &&second) +{ + return BinaryElementwiseFunction{ std::forward(first), std::forward(second), BinaryFunction::Max }; +} + +/** Represents the function call: `min(\p first, \p second)`. + * + * @tparam TFirst The type of the first argument. + * @tparam TSecond The type of the second argument. + * @param[in] first The first argument. + * @param[in] second The second argument. + * @return The resulting AST node. 
+ */ +template +BinaryElementwiseFunction min(TFirst &&first, TSecond &&second) +{ + return BinaryElementwiseFunction{ std::forward(first), std::forward(second), BinaryFunction::Min }; +} + +// ================================================== +// Ternary elementwise functions +// ================================================== + +/** AST node for ternary elementwise functions. + * + * Note that \p TFirst, \p TSecond, and \p TThird all must be operands. + * + * @tparam TFirst The type of the first argument to the function. + * @tparam TSecond The type of the second argument to the function. + * @tparam TThird The type of the third argument to the function. + */ +template ::value && can_be_operand::value && can_be_operand::value>> +struct TernaryElementwiseFunction +{ + TFirst first; + TSecond second; + TThird third; + TernaryFunction opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the function call: `select(\p first, \p second, \p third)`. + * + * @tparam TFirst The type of the first argument. + * @tparam TSecond The type of the second argument. + * @tparam TThird The type of the third argument. + * @param[in] first The first argument. + * @param[in] second The second argument. + * @param[in] third The third argument. + * @return The resulting AST node. + */ +template +TernaryElementwiseFunction select(TFirst &&first, TSecond &&second, TThird &&third) +{ + return TernaryElementwiseFunction{ std::forward(first), std::forward(second), std::forward(third), TernaryFunction::Select }; +} + +/** Helper class used to extend a KernelWriter with additional functionality + * in order to make writing easier. + * + * This extension automatically handles creation of temporary variables, and + * allows nested function calls and operations. + * + * @tparam TWriter The type of KernelWriter to be overloaded. This must inherit from KernelWriter. 
+ */ +template ::value>> +class KernelWriterHelper : public TWriter +{ +public: + using TWriter::TWriter; + + // ================================================== + // If-statements + // ================================================== + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_if; + + /** Represents the if-statement: `if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the if-statement. + */ + KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) + { + TWriter::op_if(cond.lhs, cond.opcode, cond.rhs, body); + return *this; + } + + /** Represents the if-statement: `if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the if-statement. + */ + template + KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.lhs.tile_info()); + op_assign(tmp1, cond.rhs); + TWriter::op_if(cond.lhs, cond.opcode, tmp1, body); + return *this; + } + + /** Represents the if-statement: `if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the if-statement. 
+ */ + template + KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.rhs.tile_info()); + op_assign(tmp1, cond.lhs); + TWriter::op_if(tmp1, cond.opcode, cond.rhs, body); + return *this; + } + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_else_if; + + /** Represents the else-if-statement: `else if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the else-if-statement. + */ + KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) + { + TWriter::op_else_if(cond.lhs, cond.opcode, cond.rhs, body); + return *this; + } + + /** Represents the else-if-statement: `else if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the else-if-statement. + */ + template + KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.lhs.tile_info()); + op_assign(tmp1, cond.rhs); + TWriter::op_else_if(cond.lhs, cond.opcode, tmp1, body); + return *this; + } + + /** Represents the else-if-statement: `else if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the else-if-statement. 
+ */ + template + KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.rhs.tile_info()); + op_assign(tmp1, cond.lhs); + TWriter::op_else_if(tmp1, cond.opcode, cond.rhs, body); + return *this; + } + + // ================================================== + // For-loops + // ================================================== + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_for_loop; + + /** Represents the for-loop: `for(;\p cond; \p updater) { \p body }`. + * + * The BinaryExpression for the condition and the Assignment + * for the updater are unpacked and their components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] updater The Assignment representing the updater. + * @param[in] body The body of the for-loop. + */ + void op_for_loop(const BinaryExpression &cond, const Assignment &updater, const std::function &body) + { + TWriter::op_for_loop(cond.lhs, cond.opcode, cond.rhs, updater.lhs, updater.opcode, updater.rhs, body); + } + + // ================================================== + // Unary expressions + // ================================================== + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_assign; + + /** Represents the assignment: `\p dst = \p exp`. + * + * The UnaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The UnaryExpression representing the expression to be evaluated and assigned. 
+ */ + void op_assign(const TileOperand &dst, const UnaryExpression &exp) + { + TWriter::op_unary_expression(dst, exp.opcode, exp.src); + } + + // ================================================== + // Binary expressions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, exp.rhs); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + std::cout << "Beginning assignment!" << std::endl; + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.rhs); + TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, tmp1); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + std::cout << "Beginning assignment!" 
<< std::endl; + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.lhs); + TWriter::op_binary_expression(dst, tmp1, exp.opcode, exp.rhs); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.lhs); + op_assign(tmp2, exp.rhs); + TWriter::op_binary_expression(dst, tmp1, exp.opcode, tmp2); + } + + // ================================================== + // Unary elementwise functions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The UnaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const UnaryElementwiseFunction &exp) + { + TWriter::op_unary_elementwise_function(dst, exp.opcode, exp.src); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The UnaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned. 
+ */ + template + void op_assign(const TileOperand &dst, const UnaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.lhs); + TWriter::op_unary_elementwise_function(dst, exp.opcode, tmp1); + } + + // ================================================== + // Binary elementwise functions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, exp.second); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.second); + TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, tmp1); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. 
+ */ + template + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, exp.second); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.second); + TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, tmp2); + } + + // ================================================== + // Ternary elementwise functions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. 
+ */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.second); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, tmp1); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. 
+ */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.second); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, tmp2); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.second); + op_assign(tmp2, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, tmp2); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. 
+ * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info(), dst.tile_info(), dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + auto &tmp3 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.second); + op_assign(tmp3, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, tmp3); + } + + // ================================================== + // Assignments + // ================================================== + + /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`. + * + * The Assignment is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] exp The Assignment representing the expression to be evaluated. + */ + void op_assign(const Assignment &exp) + { + if(exp.opcode == AssignmentOp::Increment) + { + TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Add, exp.rhs); + } + else if(exp.opcode == AssignmentOp::Decrement) + { + TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Sub, exp.rhs); + } + } + + /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`. + * + * The Assignment is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @tparam TRight The type of the RHS of the assignment. + * @param[in] exp The Assignment representing the expression to be evaluated. + */ + template + void op_assign(const Assignment &exp) + { + auto &tmp1 = declare_temp_tile(exp.lhs.tile_info()); + op_assign(tmp1, exp.rhs); + op_assign(Assignment{ exp.lhs, tmp1, exp.opcode }); + } + +private: + unsigned int temp_var_counter = 0; + + /** Return the current counter value, then increment it. 
+ * + * @return The current counter value. + */ + int next_ctr() + { + return temp_var_counter++; + } + + /** Gets the next temporary variable counter value, + * and returns a suitable temporary variable name. + * + * @return A temporary variable name. + */ + std::string next_tmp_var_name() + { + return "tmp_" + std::to_string(next_ctr()); + } + + /** Returns the argument. + * + * Used for recursion with the variadic function version of this function. + * + * @param[in] arg The TileInfo to return. + * @return The \p arg. + */ + TileInfo get_largest_size(const TileInfo &arg) + { + return arg; + } + + /** Returns a TileInfo object where the size in each dimension (width, height) is the largest + * of either TileInfo argument in the corresponding dimension. + * + * @tparam TOps Must be of TileInfo type. + * @param[in] first A TileInfo object. + * @param[in] second A TileInfo object. + * @param[in] ops A number of TileInfo objects. + * @return A TileInfo object which represents the largest shape in each dimension across the arguments. + */ + template ::value>> + TileInfo get_largest_size(const TileInfo &first, const TileInfo &second, const TOps &...ops) + { + TileInfo largest = { + first.data_type(), + std::max(first.width(), second.width()), + std::max(first.height(), second.height()) + }; + return get_largest_size(largest, ops...); + } + + /** Helper function to define a suitable TileOperand with appropriate TileInfo + * such that broadcasting is taken into account, based on the arguments provided. + * + * @tparam TArgs Must be of TileInfo type. + * @param[in] args A number of TileInfo which determine the shape of the TileOperand to declare. + * @return A newly created TileOperand. 
+ */ + template ::value>> + TileOperand &declare_temp_tile(const TArgs &...args) + { + return TWriter::declare_tile(next_tmp_var_name().c_str(), get_largest_size(args...)); + } +}; + +} // namespace ckw + +#endif // CKW_INCLUDE_CKW_KERNELWRITERHELPER_H diff --git a/compute_kernel_writer/prototype/include/ckw/types/Functions.h b/compute_kernel_writer/prototype/include/ckw/types/Functions.h index 68146cb1c8..2dd5ed0b3d 100644 --- a/compute_kernel_writer/prototype/include/ckw/types/Functions.h +++ b/compute_kernel_writer/prototype/include/ckw/types/Functions.h @@ -37,7 +37,6 @@ enum class UnaryFunction : int32_t Sqrt = 0x0002, Erf = 0x0003, Fabs = 0x0004, - IsGreaterEqual = 0x0005, Log = 0x0006, Round = 0x0007, diff --git a/compute_kernel_writer/prototype/include/ckw/types/Operators.h b/compute_kernel_writer/prototype/include/ckw/types/Operators.h index 172650d5ae..14a88c91b4 100644 --- a/compute_kernel_writer/prototype/include/ckw/types/Operators.h +++ b/compute_kernel_writer/prototype/include/ckw/types/Operators.h @@ -68,8 +68,8 @@ enum class BinaryOp : int32_t enum class AssignmentOp : int32_t { // Unary - Increment = 0x0000, // += - Decrement = 0x0001, // -= + Increment = 0x0000, // += + Decrement = 0x0001, // -= }; } // namespace ckw diff --git a/compute_kernel_writer/prototype/src/KernelWriter.cpp b/compute_kernel_writer/prototype/src/KernelWriter.cpp index 1ac9ede5b5..9122e518b4 100644 --- a/compute_kernel_writer/prototype/src/KernelWriter.cpp +++ b/compute_kernel_writer/prototype/src/KernelWriter.cpp @@ -270,13 +270,14 @@ void KernelWriter::op_else(const std::function &body) _impl->compound_statement_end(); } -void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function &body) +void KernelWriter::op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, 
AssignmentOp update_op, const TileOperand &update_value_name, const std::function &body) { auto impl_var_name = var_name.create_impl_operand(_impl.get()); auto impl_cond_value_name = cond_value_name.create_impl_operand(_impl.get()); + auto impl_update_var_name = update_var_name.create_impl_operand(_impl.get()); auto impl_update_value_name = update_value_name.create_impl_operand(_impl.get()); - _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, update_op, impl_update_value_name); + _impl->op_for_loop_header(impl_var_name, cond_op, impl_cond_value_name, impl_update_var_name, update_op, impl_update_value_name); _impl->compound_statement_begin(); body(); _impl->compound_statement_end(); diff --git a/compute_kernel_writer/prototype/src/Prototype.h b/compute_kernel_writer/prototype/src/Prototype.h index 72fa419fc2..05c7306e3a 100644 --- a/compute_kernel_writer/prototype/src/Prototype.h +++ b/compute_kernel_writer/prototype/src/Prototype.h @@ -2498,7 +2498,7 @@ public: virtual void op_else_header() = 0; - virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, AssignmentOp update_op, const Operand &update_value) = 0; + virtual void op_for_loop_header(const Operand &var_name, BinaryOp cond_op, const Operand &cond_value, const Operand &update_var, AssignmentOp update_op, const Operand &update_value) = 0; virtual void op_load_indirect(const TensorOperand &tensor, const Operand &dst, const Operand &x, const Operand &y_indirect, const Operand &z, const Operand &b = Operand("0", OperandType::ScalarInt32)) = 0; @@ -3654,9 +3654,6 @@ public: case UnaryFunction::Fabs: _data->code += "fabs("; break; - case UnaryFunction::IsGreaterEqual: - _data->code += "isgreaterequal("; - break; case UnaryFunction::Log: _data->code += "log("; break; @@ -3798,11 +3795,12 @@ public: _data->code += "else\n"; } - void op_for_loop_header(const Operand& var_name, BinaryOp cond_op, const Operand& cond_value_name, AssignmentOp update_op, 
const Operand& update_value_name) override + void op_for_loop_header(const Operand& var_name, BinaryOp cond_op, const Operand& cond_value_name, const Operand &update_var_name, AssignmentOp update_op, const Operand& update_value_name) override { OperandUnpacker operands(_data->tiles, _data->arguments); const IVectorTile *var = operands.unpack(var_name); const IVectorTile *cond_value = operands.unpack(cond_value_name); + const IVectorTile *update_var = operands.unpack(update_var_name); const IVectorTile *update_value = operands.unpack(update_value_name); const int32_t dst_w = var->format().w; @@ -3818,7 +3816,7 @@ public: _data->code += " "; _data->code += to_string(cond_op); _data->code += " " + cond_value->scalar(0, 0).str + "; "; - _data->code += var->scalar(0, 0).str; + _data->code += update_var->scalar(0, 0).str; _data->code += " "; _data->code += to_string(update_op); _data->code += " " + update_value->scalar(0, 0).str + ")"; -- cgit v1.2.1