From fab6c210b37f1fa6b3e37a2583b18f8e4b5a4f12 Mon Sep 17 00:00:00 2001 From: Nikolaj Jensen Date: Tue, 27 Jun 2023 14:13:24 +0100 Subject: Design wrapper around CKW for easier writing Signed-off-by: Nikolaj Jensen Change-Id: I114cdedcaf05c6abde046741837eeb73b813aa9d Signed-off-by: Nikolaj Jensen Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/532180 Tested-by: bsgcomp Reviewed-by: Viet-Hoa Do Comments-Addressed: bsgcomp Signed-off-by: Nikolaj Jensen Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9921 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- .../prototype/include/ckw/KernelWriter.h | 5 +- .../prototype/include/ckw/KernelWriterHelper.h | 1268 ++++++++++++++++++++ .../prototype/include/ckw/types/Functions.h | 1 - .../prototype/include/ckw/types/Operators.h | 4 +- 4 files changed, 1273 insertions(+), 5 deletions(-) create mode 100644 compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h (limited to 'compute_kernel_writer/prototype/include/ckw') diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h index 146fdac53e..c116e62650 100644 --- a/compute_kernel_writer/prototype/include/ckw/KernelWriter.h +++ b/compute_kernel_writer/prototype/include/ckw/KernelWriter.h @@ -230,16 +230,17 @@ public: */ void op_else(const std::function &body); - /** Write for-loops: `for(; ; ) { body }`. + /** Write for-loops: `for(; ; ) { body }`. * * @param[in] var_name The name of the variable used in condition. * @param[in] cond_op The relational binary operator used in condition. * @param[in] cond_value_name The value which the variable is compared against. + * @param[in] update_var_name The name of the variable which is updated. * @param[in] update_op The assignment operator used for updating the update value. * @param[in, out] update_value The value which is updated at every iteration. * @param[in] body The body of the for-loop. */ - void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function &body); + void op_for_loop(const TileOperand &var_name, BinaryOp cond_op, const TileOperand &cond_value_name, const TileOperand &update_var_name, AssignmentOp update_op, const TileOperand &update_value_name, const std::function &body); /** Write the return statement: `return;` */ diff --git a/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h b/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h new file mode 100644 index 0000000000..a8be859680 --- /dev/null +++ b/compute_kernel_writer/prototype/include/ckw/KernelWriterHelper.h @@ -0,0 +1,1268 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef CKW_INCLUDE_CKW_KERNELWRITERHELPER_H +#define CKW_INCLUDE_CKW_KERNELWRITERHELPER_H + +#include "ckw/KernelWriter.h" +#include "ckw/TensorOperand.h" +#include "ckw/TileOperand.h" + +#include +#include + +#include + +/* + * By including this header file you will be able to supplement the default + * Compute Kernel Writer API with additional syntax to help ease the use of CKW. + * + * To use the KernelWriterHelper you need to wrap your instance of KernelWriter + * (or any class deriving from KernelWriter): + * KernelWriterHelper writer; + * The resulting writer object comprises the original KernelWriter + * functionality (drop-in replacement), but extends the syntax as follows. + * + * Common functions/operators have natural syntax: + * 1. Unary expressions: + * writer.op_assign(dst, !src); // Logical NOT + * writer.op_assign(dst, ~src); // Bitwise NOT + * + * 2. Binary expressions: + * writer.op_assign(dst, lhs + rhs); // Addition + * writer.op_assign(dst, lhs - rhs); // Subtraction + * writer.op_assign(dst, lhs * rhs); // Multiplication + * writer.op_assign(dst, lhs / rhs); // Division + * writer.op_assign(dst, lhs % rhs); // Modulo + * writer.op_assign(dst, lhs == rhs); // Equality + * writer.op_assign(dst, lhs < rhs); // Less-than + * writer.op_assign(dst, lhs <= rhs); // Less-than-or-equal + * writer.op_assign(dst, lhs > rhs); // Greater-than + * writer.op_assign(dst, lhs >= rhs); // Greater-than-or-equal + * writer.op_assign(dst, lhs ^ rhs); // Bitwise XOR + * writer.op_assign(dst, logical_and(lhs, rhs)); // Logical AND + * writer.op_assign(dst, logical_or(lhs, rhs)); // Logical OR + * + * 3. Unary elementwise functions: + * writer.op_assign(dst, exp(src)); // Exponent + * writer.op_assign(dst, tanh(src)); // Hyperbolic tangent + * writer.op_assign(dst, sqrt(src)); // Square root + * writer.op_assign(dst, erf(src)); // Error function + * writer.op_assign(dst, fabs(src)); // Absolute of floating-point number + * writer.op_assign(dst, log(src)); // Natural logarithm + * writer.op_assign(dst, round(src)); // Round + * writer.op_assign(dst, sizeOf(src)); // sizeof + * + * 4. Binary elementwise functions: + * writer.op_assign(dst, max(first, second)); // Max + * writer.op_assign(dst, min(first, second)); // Min + * + * 5. Ternary elementwise functions: + * writer.op_assign(dst, select(first, second, third)); // Select + * + * NOTE: All the above examples support nesting, so you could write + * something like: writer.op_assign(dst, src * (log(arg) + sqrt(abs(arg))); + * + * + * 6. If-statements. The preceding syntax also allows easier writing of if-statements: + * writer.op_if(, ); + * + * For example: + * writer.op_if(exp(first_arg) == dst, [&]{ + * //... + * }).op_else_if(exp(first_arg) > dst, [&]{ + * //... + * }).op_else([&] { + * //... + * }); + * + * 7. For-loops. A similar syntax exists for for-loops: + * writer.op_for_loop(, , ); + * + * For example: + * writer.op_for_loop(index < limit, index += step, [&]{ + * //... + * }); + * + * NOTE: There are limitations on the for-loop and parameters. + * In neither the (Binary expression) or (Increment/Decrement) + * is it allowed to use nesting. For example, `(index + other) < limit` and + * `index < round(limit)` are invalid parameters. This is because the + * semantics of for-loops rely on the condition being evaluated at every iteration, + * but as temporary variables might be defined for nested expressions the semantics + * cannot be guaranteed. + */ + +namespace ckw +{ + +// ================================================== +// Type traits +// ================================================== + +/** Specifies if the type can be used as an operand for functions (e.g. max), operations (e.g. *), or assignments. */ +template +struct can_be_operand : ::std::false_type +{ +}; + +/** Specifies if the type can be assigned/written to. */ +template +struct can_be_assigned : ::std::false_type +{ +}; + +template <> +struct can_be_operand : ::std::true_type +{ +}; + +template <> +struct can_be_assigned : ::std::true_type +{ +}; + +// ================================================== +// Assignment +// ================================================== + +/** AST node for assignments. + * + * Note that \p TRight must be an operand, and \p TLeft must be assignable. + * + * @tparam TLeft The type of the destination of the assignment. + * @tparam TRight The type of the source assigned to the destination. + */ +template ::value && can_be_assigned::value>> +struct Assignment +{ + TLeft lhs; + TRight rhs; + AssignmentOp opcode; +}; + +/** Represents the expression: `\p lhs += \p rhs`. + * + * @tparam TLeft The type of the LHS of the assignment. + * @tparam TRight The type of the RHS of the assignment. + * @param[in] lhs The LHS of the assignment. + * @param[in] rhs The RHS of the assignment. + * @return The resulting AST node. + */ +template +inline Assignment operator+=(TLeft &&lhs, TRight &&rhs) +{ + return Assignment{ std::forward(lhs), std::forward(rhs), AssignmentOp::Increment }; +} + +/** Represents the expression: `\p lhs -= \p rhs`. + * + * @tparam TLeft The type of the LHS of the assignment. + * @tparam TRight The type of the RHS of the assignment. + * @param[in] lhs The LHS of the assignment. + * @param[in] rhs The RHS of the assignment. + * @return The resulting AST node. + */ +template +inline Assignment operator-=(TLeft &&lhs, TRight &&rhs) +{ + return Assignment{ std::forward(lhs), std::forward(rhs), AssignmentOp::Decrement }; +} + +// ================================================== +// Unary expression +// ================================================== + +/** AST node for unary expressions. + * + * Note that \p TSrc must be an operand. + * + * @tparam TSrc The type of the argument to the expression. + */ +template ::value>> +struct UnaryExpression +{ + TSrc src; + UnaryOp opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the expression: `!\p src`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +inline UnaryExpression operator!(TSrc &&src) +{ + return UnaryExpression{ std::forward(src), UnaryOp::LogicalNot }; +} + +/** Represents the expression: `~\p src`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +inline UnaryExpression operator~(TSrc &&src) +{ + return UnaryExpression{ std::forward(src), UnaryOp::BitwiseNot }; +} + +// ================================================== +// Binary expressions +// ================================================== + +/** AST node for binary expressions. + * + * Note that both \p TLeft and \p TRight must be operands. + * + * @tparam TLeft The type of the left argument of the expression. + * @tparam TRight The type of the right argument of the expression. + */ +template ::value && can_be_operand::value>> +struct BinaryExpression +{ + TLeft lhs; + TRight rhs; + BinaryOp opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the expression: `\p lhs + \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator+(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Add }; +} + +/** Represents the expression: `\p lhs - \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator-(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Sub }; +} + +/** Represents the expression: `\p lhs * \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator*(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Mul }; +} + +/** Represents the expression: `\p lhs / \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator/(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Div }; +} + +/** Represents the expression: `\p lhs % \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator%(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Mod }; +} + +/** Represents the expression: `\p lhs == \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator==(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Equal }; +} + +/** Represents the expression: `\p lhs < \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator<(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Less }; +} + +/** Represents the expression: `\p lhs <= \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator<=(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LessEqual }; +} + +/** Represents the expression: `\p lhs > \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator>(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::Greater }; +} + +/** Represents the expression: `\p lhs >= \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator>=(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::GreaterEqual }; +} + +/** Represents the expression: `\p lhs ^ \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression operator^(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::BitwiseXOR }; +} + +/** Represents the expression: `\p lhs && \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression logical_and(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalAnd }; +} + +/** Represents the expression: `\p lhs && \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression, TOps...> logical_and(TLeft &&lhs, TRight &&rhs, TOps &&...ops) +{ + return logical_and( + BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalAnd }, + std::forward(ops)...); +} + +/** Represents the expression: `\p lhs || \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression logical_or(TLeft &&lhs, TRight &&rhs) +{ + return BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalOr }; +} + +/** Represents the expression: `\p lhs || \p rhs`. + * + * @tparam TLeft The type of the LHS of the expression. + * @tparam TRight The type of the RHS of the expression. + * @param[in] lhs The LHS of the expression. + * @param[in] rhs The RHS of the expression. + * @return The resulting AST node. + */ +template +inline BinaryExpression, TOps...> logical_or(TLeft &&lhs, TRight &&rhs, TOps &&...ops) +{ + return logical_or( + BinaryExpression{ std::forward(lhs), std::forward(rhs), BinaryOp::LogicalOr }, + std::forward(ops)...); +} + +// ================================================== +// Unary elementwise functions +// ================================================== + +/** AST node for unary elementwise functions. + * + * Note that \p TSrc must be an operand. + * + * @tparam TSrc The type of the argument to the function. + */ +template ::value>> +struct UnaryElementwiseFunction +{ + TSrc src; + UnaryFunction opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the expression: `exp(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction exp(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Exp }; +} + +/** Represents the expression: `tanh(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction tanh(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Tanh }; +} + +/** Represents the expression: `sqrt(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction sqrt(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Sqrt }; +} + +/** Represents the expression: `erf(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction erf(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Erf }; +} + +/** Represents the expression: `fabs(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction fabs(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Fabs }; +} + +/** Represents the expression: `log(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction log(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Log }; +} + +/** Represents the expression: `round(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction round(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::Round }; +} + +/** Represents the expression: `sizeof(\p src)`. + * + * @tparam TSrc The type of the argument. + * @param[in] src The argument. + * @return The resulting AST node. + */ +template +UnaryElementwiseFunction sizeOf(TSrc &&src) +{ + return UnaryElementwiseFunction{ std::forward(src), UnaryFunction::SizeOf }; +} + +// ================================================== +// Binary elementwise functions +// ================================================== + +/** AST node for binary elementwise functions. + * + * Note that both \p TFirst and \p TSecond must be operands. + * + * @tparam TFirst The type of the left argument of the function. + * @tparam TSecond The type of the right argument of the function. + */ +template ::value && can_be_operand::value>> +struct BinaryElementwiseFunction +{ + TFirst first; + TSecond second; + BinaryFunction opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the function call: `max(\p first, \p second)`. + * + * @tparam TFirst The type of the first argument. + * @tparam TSecond The type of the second argument. + * @param[in] first The first argument. + * @param[in] second The second argument. + * @return The resulting AST node. + */ +template +BinaryElementwiseFunction max(TFirst &&first, TSecond &&second) +{ + return BinaryElementwiseFunction{ std::forward(first), std::forward(second), BinaryFunction::Max }; +} + +/** Represents the function call: `min(\p first, \p second)`. + * + * @tparam TFirst The type of the first argument. + * @tparam TSecond The type of the second argument. + * @param[in] first The first argument. + * @param[in] second The second argument. + * @return The resulting AST node. + */ +template +BinaryElementwiseFunction min(TFirst &&first, TSecond &&second) +{ + return BinaryElementwiseFunction{ std::forward(first), std::forward(second), BinaryFunction::Min }; +} + +// ================================================== +// Ternary elementwise functions +// ================================================== + +/** AST node for ternary elementwise functions. + * + * Note that \p TFirst, \p TSecond, and \p TThird all must be operands. + * + * @tparam TFirst The type of the first argument to the function. + * @tparam TSecond The type of the second argument to the function. + * @tparam TThird The type of the third argument to the function. + */ +template ::value && can_be_operand::value && can_be_operand::value>> +struct TernaryElementwiseFunction +{ + TFirst first; + TSecond second; + TThird third; + TernaryFunction opcode; +}; + +template +struct can_be_operand> : ::std::true_type +{ +}; + +/** Represents the function call: `select(\p first, \p second, \p third)`. + * + * @tparam TFirst The type of the first argument. + * @tparam TSecond The type of the second argument. + * @tparam TThird The type of the third argument. + * @param[in] first The first argument. + * @param[in] second The second argument. + * @param[in] third The third argument. + * @return The resulting AST node. + */ +template +TernaryElementwiseFunction select(TFirst &&first, TSecond &&second, TThird &&third) +{ + return TernaryElementwiseFunction{ std::forward(first), std::forward(second), std::forward(third), TernaryFunction::Select }; +} + +/** Helper class used to extend a KernelWriter with additional functionality + * in order to make writing easier. + * + * This extension automatically handles creation of temporary variables, and + * allows nested function calls and operations. + * + * @tparam TWriter The type of KernelWriter to be overloaded. This must inherit from KernelWriter. + */ +template ::value>> +class KernelWriterHelper : public TWriter +{ +public: + using TWriter::TWriter; + + // ================================================== + // If-statements + // ================================================== + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_if; + + /** Represents the if-statement: `if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the if-statement. + */ + KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) + { + TWriter::op_if(cond.lhs, cond.opcode, cond.rhs, body); + return *this; + } + + /** Represents the if-statement: `if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the if-statement. + */ + template + KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.lhs.tile_info()); + op_assign(tmp1, cond.rhs); + TWriter::op_if(cond.lhs, cond.opcode, tmp1, body); + return *this; + } + + /** Represents the if-statement: `if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the if-statement. + */ + template + KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.rhs.tile_info()); + op_assign(tmp1, cond.lhs); + TWriter::op_if(tmp1, cond.opcode, cond.rhs, body); + return *this; + } + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_else_if; + + /** Represents the else-if-statement: `else if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the else-if-statement. + */ + KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) + { + TWriter::op_else_if(cond.lhs, cond.opcode, cond.rhs, body); + return *this; + } + + /** Represents the else-if-statement: `else if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the else-if-statement. + */ + template + KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.lhs.tile_info()); + op_assign(tmp1, cond.rhs); + TWriter::op_else_if(cond.lhs, cond.opcode, tmp1, body); + return *this; + } + + /** Represents the else-if-statement: `else if(\p cond) { \p body }`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] body The body of the else-if-statement. + */ + template + KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) + { + auto &tmp1 = declare_temp_tile(cond.rhs.tile_info()); + op_assign(tmp1, cond.lhs); + TWriter::op_else_if(tmp1, cond.opcode, cond.rhs, body); + return *this; + } + + // ================================================== + // For-loops + // ================================================== + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_for_loop; + + /** Represents the for-loop: `for(;\p cond; \p updater) { \p body }`. + * + * The BinaryExpression for the condition and the Assignment + * for the updater are unpacked and their components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] cond The BinaryExpression representing the condition. + * @param[in] updater The Assignment representing the updater. + * @param[in] body The body of the for-loop. + */ + void op_for_loop(const BinaryExpression &cond, const Assignment &updater, const std::function &body) + { + TWriter::op_for_loop(cond.lhs, cond.opcode, cond.rhs, updater.lhs, updater.opcode, updater.rhs, body); + } + + // ================================================== + // Unary expressions + // ================================================== + + // Un-hide original implementation, in case the original implementation is required. + using TWriter::op_assign; + + /** Represents the assignment: `\p dst = \p exp`. + * + * The UnaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The UnaryExpression representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const UnaryExpression &exp) + { + TWriter::op_unary_expression(dst, exp.opcode, exp.src); + } + + // ================================================== + // Binary expressions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, exp.rhs); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + std::cout << "Beginning assignment!" << std::endl; + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.rhs); + TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, tmp1); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + std::cout << "Beginning assignment!" << std::endl; + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.lhs); + TWriter::op_binary_expression(dst, tmp1, exp.opcode, exp.rhs); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryExpression is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryExpression &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.lhs); + op_assign(tmp2, exp.rhs); + TWriter::op_binary_expression(dst, tmp1, exp.opcode, tmp2); + } + + // ================================================== + // Unary elementwise functions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The UnaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const UnaryElementwiseFunction &exp) + { + TWriter::op_unary_elementwise_function(dst, exp.opcode, exp.src); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The UnaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const UnaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.lhs); + TWriter::op_unary_elementwise_function(dst, exp.opcode, tmp1); + } + + // ================================================== + // Binary elementwise functions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, exp.second); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.second); + TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, tmp1); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, exp.second); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The BinaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.second); + TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, tmp2); + } + + // ================================================== + // Ternary elementwise functions + // ================================================== + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.second); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, tmp1); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.second); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, exp.third); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, tmp2); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.second); + op_assign(tmp2, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, tmp2); + } + + /** Represents the assignment: `\p dst = \p exp`. + * + * The TernaryElementwiseFunction is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] dst The tile which is assigned to. + * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. + */ + template + void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) + { + auto &tmp1 = declare_temp_tile(dst.tile_info(), dst.tile_info(), dst.tile_info()); + auto &tmp2 = declare_temp_tile(dst.tile_info()); + auto &tmp3 = declare_temp_tile(dst.tile_info()); + op_assign(tmp1, exp.first); + op_assign(tmp2, exp.second); + op_assign(tmp3, exp.third); + TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, tmp3); + } + + // ================================================== + // Assignments + // ================================================== + + /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`. + * + * The Assignment is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @param[in] exp The Assignment representing the expression to be evaluated. + */ + void op_assign(const Assignment &exp) + { + if(exp.opcode == AssignmentOp::Increment) + { + TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Add, exp.rhs); + } + else if(exp.opcode == AssignmentOp::Decrement) + { + TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Sub, exp.rhs); + } + } + + /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`. + * + * The Assignment is unpacked and its components are forwarded to + * the underlying KernelWriter's implementation. + * + * @tparam TRight The type of the RHS of the assignment. + * @param[in] exp The Assignment representing the expression to be evaluated. + */ + template + void op_assign(const Assignment &exp) + { + auto &tmp1 = declare_temp_tile(exp.lhs.tile_info()); + op_assign(tmp1, exp.rhs); + op_assign(Assignment{ exp.lhs, tmp1, exp.opcode }); + } + +private: + unsigned int temp_var_counter = 0; + + /** Return the current counter value, then increment it. + * + * @return The current counter value. + */ + int next_ctr() + { + return temp_var_counter++; + } + + /** Gets the next temporary variable counter value, + * and returns a suitable temporary variable name. + * + * @return A temporary variable name. + */ + std::string next_tmp_var_name() + { + return "tmp_" + std::to_string(next_ctr()); + } + + /** Returns the argument. + * + * Used for recursion with the variadic function version of this function. + * + * @param[in] arg The TileInfo to return. + * @return The \p arg. + */ + TileInfo get_largest_size(const TileInfo &arg) + { + return arg; + } + + /** Returns a TileInfo object where the size in each dimension (width, height) is the largest + * of either TileInfo argument in the corresponding dimension. + * + * @tparam TOps Must be of TileInfo type. + * @param[in] first A TileInfo object. + * @param[in] second A TileInfo object. + * @param[in] ops A number of TileInfo objects. + * @return A TileInfo object which represents the largest shape in each dimension across the arguments. + */ + template ::value>> + TileInfo get_largest_size(const TileInfo &first, const TileInfo &second, const TOps &...ops) + { + TileInfo largest = { + first.data_type(), + std::max(first.width(), second.width()), + std::max(first.height(), second.height()) + }; + return get_largest_size(largest, ops...); + } + + /** Helper function to define a suitable TileOperand with appropriate TileInfo + * such that broadcasting is taken into account, based on the arguments provided. + * + * @tparam TArgs Must be of TileInfo type. + * @param[in] args A number of TileInfo which determine the shape of the TileOperand to declare. + * @return A newly created TileOperand. + */ + template ::value>> + TileOperand &declare_temp_tile(const TArgs &...args) + { + return TWriter::declare_tile(next_tmp_var_name().c_str(), get_largest_size(args...)); + } +}; + +} // namespace ckw + +#endif // CKW_INCLUDE_CKW_KERNELWRITERHELPER_H diff --git a/compute_kernel_writer/prototype/include/ckw/types/Functions.h b/compute_kernel_writer/prototype/include/ckw/types/Functions.h index 68146cb1c8..2dd5ed0b3d 100644 --- a/compute_kernel_writer/prototype/include/ckw/types/Functions.h +++ b/compute_kernel_writer/prototype/include/ckw/types/Functions.h @@ -37,7 +37,6 @@ enum class UnaryFunction : int32_t Sqrt = 0x0002, Erf = 0x0003, Fabs = 0x0004, - IsGreaterEqual = 0x0005, Log = 0x0006, Round = 0x0007, diff --git a/compute_kernel_writer/prototype/include/ckw/types/Operators.h b/compute_kernel_writer/prototype/include/ckw/types/Operators.h index 172650d5ae..14a88c91b4 100644 --- a/compute_kernel_writer/prototype/include/ckw/types/Operators.h +++ b/compute_kernel_writer/prototype/include/ckw/types/Operators.h @@ -68,8 +68,8 @@ enum class BinaryOp : int32_t enum class AssignmentOp : int32_t { // Unary - Increment = 0x0000, // += - Decrement = 0x0001, // -= + Increment = 0x0000, // += + Decrement = 0x0001, // -= }; } // namespace ckw -- cgit v1.2.1