diff options
Diffstat (limited to 'compute_kernel_writer/include')
-rw-r--r-- | compute_kernel_writer/include/ckw/Error.h | 7 | ||||
-rw-r--r-- | compute_kernel_writer/include/ckw/KernelWriter.h | 62 | ||||
-rw-r--r-- | compute_kernel_writer/include/ckw/types/ConvertPolicy.h | 41 | ||||
-rw-r--r-- | compute_kernel_writer/include/ckw/types/Operators.h | 57 |
4 files changed, 153 insertions, 14 deletions
diff --git a/compute_kernel_writer/include/ckw/Error.h b/compute_kernel_writer/include/ckw/Error.h index eaf3f10c05..7da9544b9e 100644 --- a/compute_kernel_writer/include/ckw/Error.h +++ b/compute_kernel_writer/include/ckw/Error.h @@ -113,6 +113,13 @@ inline void ignore_unused(T &&...) */ #define CKW_ASSERT(cond) CKW_ASSERT_MSG(cond, #cond) +/** If the precondition is met but the condition is not met, throw an std::runtime_error if assertion is enabled. + * + * @param[in] precond The precondition that triggers the check. + * @param[in] cond The condition that is expected to be true if precondition is true. + */ +#define CKW_ASSERT_IF(precond, cond) CKW_ASSERT(!(precond) || (cond)) + /** Throw an std::runtime_error with the specified message if assertion is enabled. * * @param[in] msg The error message when the condition is not met. diff --git a/compute_kernel_writer/include/ckw/KernelWriter.h b/compute_kernel_writer/include/ckw/KernelWriter.h index f77798e2ab..7eb6d2894a 100644 --- a/compute_kernel_writer/include/ckw/KernelWriter.h +++ b/compute_kernel_writer/include/ckw/KernelWriter.h @@ -27,7 +27,10 @@ #include "ckw/TensorOperand.h" #include "ckw/TileOperand.h" +#include "ckw/types/ConvertPolicy.h" +#include "ckw/types/Operators.h" +#include <functional> #include <memory> #include <string> @@ -76,6 +79,33 @@ public: virtual ~KernelWriter(); // ============================================================================================= + // Data processing + // ============================================================================================= + + /** Write assignment statement: `<dst> = <src>;`. + * + * @param[in] dst The destination tile. + * @param[in] src The source tile. + */ + virtual void op_assign(const TileOperand &dst, const TileOperand &src) = 0; + + /** Write the cast statement: `<dst> = convert_<dst.type><policy>(<src>);`. + * + * @param[in] dst The destination tile. + * @param[in] src The source tile. + * @param[in] policy The policy governing the behavior of the cast. + */ + virtual void op_cast(const TileOperand &dst, const TileOperand &src, ConvertPolicy policy) = 0; + + /** Write the unary expression statement: `<dst> = <op> <src>;`. + * + * @param[in] dst The destination tile. + * @param[in] src The source tile. + * @param[in] op The unary operator. + */ + virtual void op_unary(const TileOperand &dst, const TileOperand &src, UnaryOp op) = 0; + + // ============================================================================================= // Misc // ============================================================================================= @@ -87,7 +117,16 @@ public: * * @param[in] text The comment to be written. */ - virtual void comment(const std::string &text) = 0; + virtual void op_comment(const std::string &text) = 0; + + /** Write the given raw code to kernel source code + * It's used to address the cases where the user needs to + * explicitly add a code where it's not (yet) supported by + * the kernel writer utility calls. + * + * @param[in] raw_code raw code to write as string + */ + virtual void op_write_raw_code(const std::string &raw_code) = 0; // ============================================================================================= // Code generation @@ -121,15 +160,6 @@ public: */ virtual TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) = 0; - /** Write the given raw code to kernel source code - * It's used to address the cases where the user needs to - * explicitly add a code where it's not (yet) supported by - * the kernel writer utility calls. - * - * @param[in] raw_code raw code to write as string - */ - virtual void op_write_raw_code(const std::string &raw_code) = 0; - /** Load the data from the tensor memory to the tile using the sampling information. * * @param[in] tile_op The tile to be loaded. @@ -140,7 +170,8 @@ public: * @param[in] z z-coordinate * @param[in] batch batch offset */ - virtual void op_load(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, + virtual void op_load( + const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0; /** Load the data from the tensor memory to the tile in a dilated way using the sampling information. @@ -150,7 +181,8 @@ public: * @param[in] dilation_x Dilation while reading in x-dimension * @param[in] dilation_y Dilation while reading in y-dimension */ - virtual void op_load_dilated(const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, + virtual void op_load_dilated( + const TileOperand &tile_op, const TensorOperand &tensor_op, TensorSampler &sampler, const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, const TileOperand &dilation_x, const TileOperand &dilation_y) = 0; @@ -158,14 +190,16 @@ public: * * Similar to @ref KernelWriter::op_load() */ - virtual void op_store(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, + virtual void op_store( + const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch) = 0; /** Store the data to the tensor memory from the tile in a dilated way using the sampling information. * * Similar to @ref KernelWriter::op_load_dilated() */ - virtual void op_store_dilated(const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, + virtual void op_store_dilated( + const TensorOperand &tensor_op, const TileOperand &tile_op, TensorSampler &sampler, const TileOperand &x, const TileOperand &y, const TileOperand &z, const TileOperand &batch, const TileOperand &dilation_x, const TileOperand &dilation_y) = 0; diff --git a/compute_kernel_writer/include/ckw/types/ConvertPolicy.h b/compute_kernel_writer/include/ckw/types/ConvertPolicy.h new file mode 100644 index 0000000000..43a37ff118 --- /dev/null +++ b/compute_kernel_writer/include/ckw/types/ConvertPolicy.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef CKW_INCLUDE_CKW_TYPES_CONVERTPOLICY_H +#define CKW_INCLUDE_CKW_TYPES_CONVERTPOLICY_H + +#include <cstdint> + +namespace ckw +{ + +enum class ConvertPolicy : int32_t +{ + None = 0, // No policy specified. + Saturate = 1, // Saturated. +}; + +} // namespace ckw + +#endif // CKW_INCLUDE_CKW_TYPES_CONVERTPOLICY_H diff --git a/compute_kernel_writer/include/ckw/types/Operators.h b/compute_kernel_writer/include/ckw/types/Operators.h new file mode 100644 index 0000000000..ec2df08c46 --- /dev/null +++ b/compute_kernel_writer/include/ckw/types/Operators.h @@ -0,0 +1,57 @@ +/* +* Copyright (c) 2023 Arm Limited. +* +* SPDX-License-Identifier: MIT +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to +* deal in the Software without restriction, including without limitation the +* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +* sell copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in all +* copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ + +#ifndef CKW_INCLUDE_CKW_TYPES_OPERATORS_H +#define CKW_INCLUDE_CKW_TYPES_OPERATORS_H + +#include <cstdint> + +namespace ckw +{ + +/** Unary operators and functions. */ +enum class UnaryOp : int32_t +{ + LogicalNot = 0x0000, // ! + BitwiseNot = 0x0001, // ~ + + Exp = 0x0010, + Tanh = 0x0011, + Sqrt = 0x0012, + Erf = 0x0013, + Fabs = 0x0014, + Log = 0x0015, + Round = 0x0016, +}; + +/** Assignment operators. */ +enum class AssignmentOp : int32_t +{ + Increment = 0x0000, // += + Decrement = 0x0001, // -= +}; + +} // namespace ckw + +#endif // CKW_INCLUDE_CKW_TYPES_OPERATORS_H |