/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef CKW_INCLUDE_CKW_KERNELWRITERHELPER_H
#define CKW_INCLUDE_CKW_KERNELWRITERHELPER_H

#include "ckw/KernelWriter.h"
#include "ckw/TensorOperand.h"
#include "ckw/TileOperand.h"

#include <functional>
#include <iostream>
#include <type_traits>
#include <utility>

/*
 * By including this header file you will be able to supplement the default
 * Compute Kernel Writer API with additional syntax to help ease the use of CKW.
 *
 * To use the KernelWriterHelper you need to wrap your instance of KernelWriter
 * (or any class deriving from KernelWriter):
 *     KernelWriterHelper<KernelWriter> writer;
 * The resulting writer object comprises the original KernelWriter
 * functionality (drop-in replacement), but extends the syntax as follows.
 *
 * Common functions/operators have natural syntax:
 *  1. Unary expressions:
 *      writer.op_assign(dst, !src);   // Logical NOT
 *      writer.op_assign(dst, ~src);   // Bitwise NOT
 *
 *  2.
Binary expressions:
 *      writer.op_assign(dst, lhs + rhs);               // Addition
 *      writer.op_assign(dst, lhs - rhs);               // Subtraction
 *      writer.op_assign(dst, lhs * rhs);               // Multiplication
 *      writer.op_assign(dst, lhs / rhs);               // Division
 *      writer.op_assign(dst, lhs % rhs);               // Modulo
 *      writer.op_assign(dst, lhs == rhs);              // Equality
 *      writer.op_assign(dst, lhs < rhs);               // Less-than
 *      writer.op_assign(dst, lhs <= rhs);              // Less-than-or-equal
 *      writer.op_assign(dst, lhs > rhs);               // Greater-than
 *      writer.op_assign(dst, lhs >= rhs);              // Greater-than-or-equal
 *      writer.op_assign(dst, lhs ^ rhs);               // Bitwise XOR
 *      writer.op_assign(dst, logical_and(lhs, rhs));   // Logical AND
 *      writer.op_assign(dst, logical_or(lhs, rhs));    // Logical OR
 *
 *  3. Unary elementwise functions:
 *      writer.op_assign(dst, exp(src));      // Exponent
 *      writer.op_assign(dst, tanh(src));     // Hyperbolic tangent
 *      writer.op_assign(dst, sqrt(src));     // Square root
 *      writer.op_assign(dst, erf(src));      // Error function
 *      writer.op_assign(dst, fabs(src));     // Absolute of floating-point number
 *      writer.op_assign(dst, log(src));      // Natural logarithm
 *      writer.op_assign(dst, round(src));    // Round
 *      writer.op_assign(dst, sizeOf(src));   // sizeof
 *
 *  4. Binary elementwise functions:
 *      writer.op_assign(dst, max(first, second));   // Max
 *      writer.op_assign(dst, min(first, second));   // Min
 *
 *  5. Ternary elementwise functions:
 *      writer.op_assign(dst, select(first, second, third));   // Select
 *
 *  NOTE: All the above examples support nesting, so you could write
 *  something like: writer.op_assign(dst, src * (log(arg) + sqrt(fabs(arg))));
 *
 *
 *  6. If-statements. The preceding syntax also allows easier writing of if-statements:
 *      writer.op_if(<cond>, <body>);
 *
 *      For example:
 *          writer.op_if(exp(first_arg) == dst, [&]{
 *              //...
 *          }).op_else_if(exp(first_arg) > dst, [&]{
 *              //...
 *          }).op_else([&] {
 *              //...
 *          });
 *
 *  7. For-loops. A similar syntax exists for for-loops:
 *      writer.op_for_loop(<cond>, <updater>, <body>);
 *
 *      For example:
 *          writer.op_for_loop(index < limit, index += step, [&]{
 *              //...
* }); * * NOTE: There are limitations on the for-loop and parameters. * In neither the (Binary expression) or (Increment/Decrement) * is it allowed to use nesting. For example, `(index + other) < limit` and * `index < round(limit)` are invalid parameters. This is because the * semantics of for-loops rely on the condition being evaluated at every iteration, * but as temporary variables might be defined for nested expressions the semantics * cannot be guaranteed. */ namespace ckw { // ================================================== // Type traits // ================================================== /** Specifies if the type can be used as an operand for functions (e.g. max), operations (e.g. *), or assignments. */ template struct can_be_operand : ::std::false_type { }; /** Specifies if the type can be assigned/written to. */ template struct can_be_assigned : ::std::false_type { }; template <> struct can_be_operand : ::std::true_type { }; template <> struct can_be_assigned : ::std::true_type { }; // ================================================== // Assignment // ================================================== /** AST node for assignments. * * Note that \p TRight must be an operand, and \p TLeft must be assignable. * * @tparam TLeft The type of the destination of the assignment. * @tparam TRight The type of the source assigned to the destination. */ template ::value && can_be_assigned::value>> struct Assignment { TLeft lhs; TRight rhs; AssignmentOp opcode; }; /** Represents the expression: `\p lhs += \p rhs`. * * @tparam TLeft The type of the LHS of the assignment. * @tparam TRight The type of the RHS of the assignment. * @param[in] lhs The LHS of the assignment. * @param[in] rhs The RHS of the assignment. * @return The resulting AST node. */ template inline Assignment operator+=(TLeft &&lhs, TRight &&rhs) { return Assignment{std::forward(lhs), std::forward(rhs), AssignmentOp::Increment}; } /** Represents the expression: `\p lhs -= \p rhs`. 
* * @tparam TLeft The type of the LHS of the assignment. * @tparam TRight The type of the RHS of the assignment. * @param[in] lhs The LHS of the assignment. * @param[in] rhs The RHS of the assignment. * @return The resulting AST node. */ template inline Assignment operator-=(TLeft &&lhs, TRight &&rhs) { return Assignment{std::forward(lhs), std::forward(rhs), AssignmentOp::Decrement}; } // ================================================== // Unary expression // ================================================== /** AST node for unary expressions. * * Note that \p TSrc must be an operand. * * @tparam TSrc The type of the argument to the expression. */ template ::value>> struct UnaryExpression { TSrc src; UnaryOp opcode; }; template struct can_be_operand> : ::std::true_type { }; /** Represents the expression: `!\p src`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template inline UnaryExpression operator!(TSrc &&src) { return UnaryExpression{std::forward(src), UnaryOp::LogicalNot}; } /** Represents the expression: `~\p src`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template inline UnaryExpression operator~(TSrc &&src) { return UnaryExpression{std::forward(src), UnaryOp::BitwiseNot}; } // ================================================== // Binary expressions // ================================================== /** AST node for binary expressions. * * Note that both \p TLeft and \p TRight must be operands. * * @tparam TLeft The type of the left argument of the expression. * @tparam TRight The type of the right argument of the expression. */ template ::value && can_be_operand::value>> struct BinaryExpression { TLeft lhs; TRight rhs; BinaryOp opcode; }; template struct can_be_operand> : ::std::true_type { }; /** Represents the expression: `\p lhs + \p rhs`. * * @tparam TLeft The type of the LHS of the expression. 
* @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator+(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Add}; } /** Represents the expression: `\p lhs - \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator-(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Sub}; } /** Represents the expression: `\p lhs * \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator*(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Mul}; } /** Represents the expression: `\p lhs / \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator/(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Div}; } /** Represents the expression: `\p lhs % \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. 
*/ template inline BinaryExpression operator%(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Mod}; } /** Represents the expression: `\p lhs == \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator==(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Equal}; } /** Represents the expression: `\p lhs < \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator<(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Less}; } /** Represents the expression: `\p lhs <= \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator<=(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::LessEqual}; } /** Represents the expression: `\p lhs > \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator>(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::Greater}; } /** Represents the expression: `\p lhs >= \p rhs`. 
* * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator>=(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::GreaterEqual}; } /** Represents the expression: `\p lhs ^ \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression operator^(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::BitwiseXOR}; } /** Represents the expression: `\p lhs && \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression logical_and(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::LogicalAnd}; } /** Represents the expression: `\p lhs && \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression, TOps...> logical_and(TLeft &&lhs, TRight &&rhs, TOps &&...ops) { return logical_and( BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::LogicalAnd}, std::forward(ops)...); } /** Represents the expression: `\p lhs || \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. 
* @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression logical_or(TLeft &&lhs, TRight &&rhs) { return BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::LogicalOr}; } /** Represents the expression: `\p lhs || \p rhs`. * * @tparam TLeft The type of the LHS of the expression. * @tparam TRight The type of the RHS of the expression. * @param[in] lhs The LHS of the expression. * @param[in] rhs The RHS of the expression. * @return The resulting AST node. */ template inline BinaryExpression, TOps...> logical_or(TLeft &&lhs, TRight &&rhs, TOps &&...ops) { return logical_or( BinaryExpression{std::forward(lhs), std::forward(rhs), BinaryOp::LogicalOr}, std::forward(ops)...); } // ================================================== // Unary elementwise functions // ================================================== /** AST node for unary elementwise functions. * * Note that \p TSrc must be an operand. * * @tparam TSrc The type of the argument to the function. */ template ::value>> struct UnaryElementwiseFunction { TSrc src; UnaryFunction opcode; }; template struct can_be_operand> : ::std::true_type { }; /** Represents the expression: `exp(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template UnaryElementwiseFunction exp(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::Exp}; } /** Represents the expression: `tanh(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template UnaryElementwiseFunction tanh(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::Tanh}; } /** Represents the expression: `sqrt(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. 
*/ template UnaryElementwiseFunction sqrt(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::Sqrt}; } /** Represents the expression: `erf(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template UnaryElementwiseFunction erf(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::Erf}; } /** Represents the expression: `fabs(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template UnaryElementwiseFunction fabs(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::Fabs}; } /** Represents the expression: `log(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template UnaryElementwiseFunction log(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::Log}; } /** Represents the expression: `round(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template UnaryElementwiseFunction round(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::Round}; } /** Represents the expression: `sizeof(\p src)`. * * @tparam TSrc The type of the argument. * @param[in] src The argument. * @return The resulting AST node. */ template UnaryElementwiseFunction sizeOf(TSrc &&src) { return UnaryElementwiseFunction{std::forward(src), UnaryFunction::SizeOf}; } // ================================================== // Binary elementwise functions // ================================================== /** AST node for binary elementwise functions. * * Note that both \p TFirst and \p TSecond must be operands. * * @tparam TFirst The type of the left argument of the function. * @tparam TSecond The type of the right argument of the function. 
*/ template ::value && can_be_operand::value>> struct BinaryElementwiseFunction { TFirst first; TSecond second; BinaryFunction opcode; }; template struct can_be_operand> : ::std::true_type { }; /** Represents the function call: `max(\p first, \p second)`. * * @tparam TFirst The type of the first argument. * @tparam TSecond The type of the second argument. * @param[in] first The first argument. * @param[in] second The second argument. * @return The resulting AST node. */ template BinaryElementwiseFunction max(TFirst &&first, TSecond &&second) { return BinaryElementwiseFunction{std::forward(first), std::forward(second), BinaryFunction::Max}; } /** Represents the function call: `min(\p first, \p second)`. * * @tparam TFirst The type of the first argument. * @tparam TSecond The type of the second argument. * @param[in] first The first argument. * @param[in] second The second argument. * @return The resulting AST node. */ template BinaryElementwiseFunction min(TFirst &&first, TSecond &&second) { return BinaryElementwiseFunction{std::forward(first), std::forward(second), BinaryFunction::Min}; } // ================================================== // Ternary elementwise functions // ================================================== /** AST node for ternary elementwise functions. * * Note that \p TFirst, \p TSecond, and \p TThird all must be operands. * * @tparam TFirst The type of the first argument to the function. * @tparam TSecond The type of the second argument to the function. * @tparam TThird The type of the third argument to the function. */ template ::value && can_be_operand::value && can_be_operand::value>> struct TernaryElementwiseFunction { TFirst first; TSecond second; TThird third; TernaryFunction opcode; }; template struct can_be_operand> : ::std::true_type { }; /** Represents the function call: `select(\p first, \p second, \p third)`. * * @tparam TFirst The type of the first argument. * @tparam TSecond The type of the second argument. 
* @tparam TThird The type of the third argument. * @param[in] first The first argument. * @param[in] second The second argument. * @param[in] third The third argument. * @return The resulting AST node. */ template TernaryElementwiseFunction select(TFirst &&first, TSecond &&second, TThird &&third) { return TernaryElementwiseFunction{std::forward(first), std::forward(second), std::forward(third), TernaryFunction::Select}; } /** Helper class used to extend a KernelWriter with additional functionality * in order to make writing easier. * * This extension automatically handles creation of temporary variables, and * allows nested function calls and operations. * * @tparam TWriter The type of KernelWriter to be overloaded. This must inherit from KernelWriter. */ template ::value>> class KernelWriterHelper : public TWriter { public: using TWriter::TWriter; // ================================================== // If-statements // ================================================== // Un-hide original implementation, in case the original implementation is required. using TWriter::op_if; /** Represents the if-statement: `if(\p cond) { \p body }`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] cond The BinaryExpression representing the condition. * @param[in] body The body of the if-statement. */ KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) { TWriter::op_if(cond.lhs, cond.opcode, cond.rhs, body); return *this; } /** Represents the if-statement: `if(\p cond) { \p body }`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] cond The BinaryExpression representing the condition. * @param[in] body The body of the if-statement. 
*/ template KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) { auto &tmp1 = declare_temp_tile(cond.lhs.tile_info()); op_assign(tmp1, cond.rhs); TWriter::op_if(cond.lhs, cond.opcode, tmp1, body); return *this; } /** Represents the if-statement: `if(\p cond) { \p body }`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] cond The BinaryExpression representing the condition. * @param[in] body The body of the if-statement. */ template KernelWriterHelper &op_if(const BinaryExpression &cond, const std::function &body) { auto &tmp1 = declare_temp_tile(cond.rhs.tile_info()); op_assign(tmp1, cond.lhs); TWriter::op_if(tmp1, cond.opcode, cond.rhs, body); return *this; } // Un-hide original implementation, in case the original implementation is required. using TWriter::op_else_if; /** Represents the else-if-statement: `else if(\p cond) { \p body }`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] cond The BinaryExpression representing the condition. * @param[in] body The body of the else-if-statement. */ KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) { TWriter::op_else_if(cond.lhs, cond.opcode, cond.rhs, body); return *this; } /** Represents the else-if-statement: `else if(\p cond) { \p body }`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] cond The BinaryExpression representing the condition. * @param[in] body The body of the else-if-statement. 
*/ template KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) { auto &tmp1 = declare_temp_tile(cond.lhs.tile_info()); op_assign(tmp1, cond.rhs); TWriter::op_else_if(cond.lhs, cond.opcode, tmp1, body); return *this; } /** Represents the else-if-statement: `else if(\p cond) { \p body }`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] cond The BinaryExpression representing the condition. * @param[in] body The body of the else-if-statement. */ template KernelWriterHelper &op_else_if(const BinaryExpression &cond, const std::function &body) { auto &tmp1 = declare_temp_tile(cond.rhs.tile_info()); op_assign(tmp1, cond.lhs); TWriter::op_else_if(tmp1, cond.opcode, cond.rhs, body); return *this; } // ================================================== // For-loops // ================================================== // Un-hide original implementation, in case the original implementation is required. using TWriter::op_for_loop; /** Represents the for-loop: `for(;\p cond; \p updater) { \p body }`. * * The BinaryExpression for the condition and the Assignment * for the updater are unpacked and their components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] cond The BinaryExpression representing the condition. * @param[in] updater The Assignment representing the updater. * @param[in] body The body of the for-loop. */ void op_for_loop(const BinaryExpression &cond, const Assignment &updater, const std::function &body) { TWriter::op_for_loop(cond.lhs, cond.opcode, cond.rhs, updater.lhs, updater.opcode, updater.rhs, body); } // ================================================== // Unary expressions // ================================================== // Un-hide original implementation, in case the original implementation is required. using TWriter::op_assign; /** Represents the assignment: `\p dst = \p exp`. 
* * The UnaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The UnaryExpression representing the expression to be evaluated and assigned. */ void op_assign(const TileOperand &dst, const UnaryExpression &exp) { TWriter::op_unary_expression(dst, exp.opcode, exp.src); } // ================================================== // Binary expressions // ================================================== /** Represents the assignment: `\p dst = \p exp`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. */ void op_assign(const TileOperand &dst, const BinaryExpression &exp) { TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, exp.rhs); } /** Represents the assignment: `\p dst = \p exp`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const BinaryExpression &exp) { std::cout << "Beginning assignment!" << std::endl; auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.rhs); TWriter::op_binary_expression(dst, exp.lhs, exp.opcode, tmp1); } /** Represents the assignment: `\p dst = \p exp`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. 
*/ template void op_assign(const TileOperand &dst, const BinaryExpression &exp) { std::cout << "Beginning assignment!" << std::endl; auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.lhs); TWriter::op_binary_expression(dst, tmp1, exp.opcode, exp.rhs); } /** Represents the assignment: `\p dst = \p exp`. * * The BinaryExpression is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryExpression representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const BinaryExpression &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); auto &tmp2 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.lhs); op_assign(tmp2, exp.rhs); TWriter::op_binary_expression(dst, tmp1, exp.opcode, tmp2); } // ================================================== // Unary elementwise functions // ================================================== /** Represents the assignment: `\p dst = \p exp`. * * The UnaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned. */ void op_assign(const TileOperand &dst, const UnaryElementwiseFunction &exp) { TWriter::op_unary_elementwise_function(dst, exp.opcode, exp.src); } /** Represents the assignment: `\p dst = \p exp`. * * The UnaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The UnaryElementwiseFunction representing the expression to be evaluated and assigned. 
*/ template void op_assign(const TileOperand &dst, const UnaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.lhs); TWriter::op_unary_elementwise_function(dst, exp.opcode, tmp1); } // ================================================== // Binary elementwise functions // ================================================== /** Represents the assignment: `\p dst = \p exp`. * * The BinaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. */ void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) { TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, exp.second); } /** Represents the assignment: `\p dst = \p exp`. * * The BinaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.second); TWriter::op_binary_elementwise_function(dst, exp.opcode, exp.first, tmp1); } /** Represents the assignment: `\p dst = \p exp`. * * The BinaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. 
*/ template void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.first); TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, exp.second); } /** Represents the assignment: `\p dst = \p exp`. * * The BinaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The BinaryElementwiseFunction representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const BinaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); auto &tmp2 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.first); op_assign(tmp2, exp.second); TWriter::op_binary_elementwise_function(dst, exp.opcode, tmp1, tmp2); } // ================================================== // Ternary elementwise functions // ================================================== /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. */ void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, exp.third); } /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. 
*/ template void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.first); TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, exp.third); } /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.second); TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, exp.third); } /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.third); TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, exp.second, tmp1); } /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. 
*/ template void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); auto &tmp2 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.first); op_assign(tmp2, exp.second); TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, exp.third); } /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); auto &tmp2 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.first); op_assign(tmp2, exp.third); TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, exp.second, tmp2); } /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. * @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info()); auto &tmp2 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.second); op_assign(tmp2, exp.third); TWriter::op_ternary_elementwise_function(dst, exp.opcode, exp.first, tmp1, tmp2); } /** Represents the assignment: `\p dst = \p exp`. * * The TernaryElementwiseFunction is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] dst The tile which is assigned to. 
* @param[in] exp The TernaryElementwiseFunction representing the expression to be evaluated and assigned. */ template void op_assign(const TileOperand &dst, const TernaryElementwiseFunction &exp) { auto &tmp1 = declare_temp_tile(dst.tile_info(), dst.tile_info(), dst.tile_info()); auto &tmp2 = declare_temp_tile(dst.tile_info()); auto &tmp3 = declare_temp_tile(dst.tile_info()); op_assign(tmp1, exp.first); op_assign(tmp2, exp.second); op_assign(tmp3, exp.third); TWriter::op_ternary_elementwise_function(dst, exp.opcode, tmp1, tmp2, tmp3); } // ================================================== // Assignments // ================================================== /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`. * * The Assignment is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @param[in] exp The Assignment representing the expression to be evaluated. */ void op_assign(const Assignment &exp) { if (exp.opcode == AssignmentOp::Increment) { TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Add, exp.rhs); } else if (exp.opcode == AssignmentOp::Decrement) { TWriter::op_binary_expression(exp.lhs, exp.lhs, BinaryOp::Sub, exp.rhs); } } /** Represents the assignment: `\p lhs += \p rhs` or `\p lhs -= \p rhs`. * * The Assignment is unpacked and its components are forwarded to * the underlying KernelWriter's implementation. * * @tparam TRight The type of the RHS of the assignment. * @param[in] exp The Assignment representing the expression to be evaluated. */ template void op_assign(const Assignment &exp) { auto &tmp1 = declare_temp_tile(exp.lhs.tile_info()); op_assign(tmp1, exp.rhs); op_assign(Assignment{exp.lhs, tmp1, exp.opcode}); } private: unsigned int temp_var_counter = 0; /** Return the current counter value, then increment it. * * @return The current counter value. 
*/ int next_ctr() { return temp_var_counter++; } /** Gets the next temporary variable counter value, * and returns a suitable temporary variable name. * * @return A temporary variable name. */ std::string next_tmp_var_name() { return "tmp_" + std::to_string(next_ctr()); } /** Returns the argument. * * Used for recursion with the variadic function version of this function. * * @param[in] arg The TileInfo to return. * @return The \p arg. */ TileInfo get_largest_size(const TileInfo &arg) { return arg; } /** Returns a TileInfo object where the size in each dimension (width, height) is the largest * of either TileInfo argument in the corresponding dimension. * * @tparam TOps Must be of TileInfo type. * @param[in] first A TileInfo object. * @param[in] second A TileInfo object. * @param[in] ops A number of TileInfo objects. * @return A TileInfo object which represents the largest shape in each dimension across the arguments. */ template ::value>> TileInfo get_largest_size(const TileInfo &first, const TileInfo &second, const TOps &...ops) { TileInfo largest = {first.data_type(), std::max(first.width(), second.width()), std::max(first.height(), second.height())}; return get_largest_size(largest, ops...); } /** Helper function to define a suitable TileOperand with appropriate TileInfo * such that broadcasting is taken into account, based on the arguments provided. * * @tparam TArgs Must be of TileInfo type. * @param[in] args A number of TileInfo which determine the shape of the TileOperand to declare. * @return A newly created TileOperand. */ template ::value>> TileOperand &declare_temp_tile(const TArgs &...args) { return TWriter::declare_tile(next_tmp_var_name().c_str(), get_largest_size(args...)); } }; } // namespace ckw #endif // CKW_INCLUDE_CKW_KERNELWRITERHELPER_H