aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2023-08-23 23:28:31 +0100
committerGunes Bayir <gunes.bayir@arm.com>2023-08-29 11:07:48 +0000
commit806b8e856911e6691ede6725c7e2a0e7e0dd6e95 (patch)
tree2430af238e9494a1b7012b05a3b49b2eef548cd2 /compute_kernel_writer/src
parentb7aefd71d07d56b001e795410700cae71a518eca (diff)
downloadComputeLibrary-806b8e856911e6691ede6725c7e2a0e7e0dd6e95.tar.gz
Add declare_constant_tile API function in CKW
Resolves: COMPMID-6535 Change-Id: I07d8aca96a0fcbd624f828b24513ee0500a14a74 Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10200 Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'compute_kernel_writer/src')
-rw-r--r--compute_kernel_writer/src/KernelWriter.cpp10
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.cpp25
-rw-r--r--compute_kernel_writer/src/cl/CLKernelWriter.h15
-rw-r--r--compute_kernel_writer/src/types/ConstantData.cpp120
4 files changed, 165 insertions, 5 deletions
diff --git a/compute_kernel_writer/src/KernelWriter.cpp b/compute_kernel_writer/src/KernelWriter.cpp
index ce34a1c2d6..21a61d73bf 100644
--- a/compute_kernel_writer/src/KernelWriter.cpp
+++ b/compute_kernel_writer/src/KernelWriter.cpp
@@ -81,4 +81,14 @@ ITensor &KernelWriter::get_tensor(const TensorOperand &operand)
return operand._tensor;
}
+const std::vector<std::vector<std::string>> &KernelWriter::get_values(const ConstantData &data)
+{
+ return data.values();
+}
+
+DataType KernelWriter::get_data_type(const ConstantData &data)
+{
+ return data.data_type();
+}
+
} // namespace ckw
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.cpp b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
index 312162f498..79d0f985d0 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.cpp
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.cpp
@@ -28,6 +28,7 @@
#include "ckw/Kernel.h"
#include "ckw/TensorSampler.h"
#include "ckw/TileOperand.h"
+#include "ckw/types/DataType.h"
#include "ckw/types/MemoryOperation.h"
#include "ckw/types/TargetLanguage.h"
#include "src/ITensorComponent.h"
@@ -37,7 +38,6 @@
#include "src/cl/helpers/CLMemoryOpBufferHelper.h"
#include "src/cl/helpers/CLMemoryOpImage2dHelper.h"
#include "src/cl/helpers/ICLMemoryOpHelper.h"
-
#include "src/types/DataTypeHelpers.h"
#include <algorithm>
@@ -364,7 +364,7 @@ TileOperand CLKernelWriter::declare_tile(const std::string &name, const TileInfo
return e->name() == fullname;
})
== _tiles.end(),
- "Tile name must be unique.");
+ "There is already a tile with name: " + fullname);
auto tile = std::make_unique<CLTile>(fullname, tile_info);
@@ -381,17 +381,27 @@ TileOperand CLKernelWriter::declare_tile(const std::string &name, const TileInfo
return operand;
}
+TileOperand CLKernelWriter::declare_constant_tile(const ConstantData &data)
+{
+ auto tile = std::make_unique<CLTile>(get_values(data), get_data_type(data));
+ const TileOperand operand = create_tile_operand(*tile);
+ _constant_tiles.insert(std::move(tile));
+
+ return operand;
+}
+
void CLKernelWriter::op_write_raw_code(const std::string &raw_code)
{
append_code(raw_code);
}
-const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand)
+const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand) const
{
const auto &tile = get_tile(operand);
#ifdef COMPUTE_KERNEL_WRITER_ASSERTS_ENABLED
// Check if the tile is a CLTile created by this kernel writer.
+
{
bool found = false;
@@ -404,6 +414,15 @@ const CLTile &CLKernelWriter::to_cl_tile(const TileOperand &operand)
}
}
+ for(const auto &t : _constant_tiles)
+ {
+ if(&tile == t.get())
+ {
+ found = true;
+ break;
+ }
+ }
+
if(!found)
{
for(const auto &t : _tensors)
diff --git a/compute_kernel_writer/src/cl/CLKernelWriter.h b/compute_kernel_writer/src/cl/CLKernelWriter.h
index 2a6b79c691..d2c84f192e 100644
--- a/compute_kernel_writer/src/cl/CLKernelWriter.h
+++ b/compute_kernel_writer/src/cl/CLKernelWriter.h
@@ -29,17 +29,21 @@
#include <memory>
#include <set>
+#include <string>
#include <utility>
namespace ckw
{
+// Forward Declarations
class CLTile;
class CLTensorArgument;
+class ConstantData;
+class TensorOperand;
class TensorSampler;
class TileOperand;
-class TensorOperand;
+enum class DataType;
enum class MemoryOperation;
/** OpenCL kernel writer. */
@@ -96,6 +100,12 @@ public:
*/
TileOperand declare_tile(const std::string &name, const TileInfo &tile_info) override;
+ /** Declare a constant tile given a @ref:ConstantData object
+ *
+ * Similar to @ref KernelWriter::declare_constant_tile()
+ */
+ TileOperand declare_constant_tile(const ConstantData &data) override;
+
// =============================================================================================
// Memory Operations
// =============================================================================================
@@ -139,7 +149,7 @@ protected:
*
* This function performs appropriate check before doing type casting.
*/
- const CLTile &to_cl_tile(const TileOperand &operand);
+ const CLTile &to_cl_tile(const TileOperand &operand) const;
/** Append the specified code to the kernel body source code. */
template <typename T, typename... TArgs>
@@ -179,6 +189,7 @@ private:
std::set<std::unique_ptr<CLTensorArgument>> _tensors{};
std::set<std::unique_ptr<CLTile>> _tiles{};
+ std::set<std::unique_ptr<CLTile>> _constant_tiles{};
};
} // namespace ckw
diff --git a/compute_kernel_writer/src/types/ConstantData.cpp b/compute_kernel_writer/src/types/ConstantData.cpp
new file mode 100644
index 0000000000..d2155cf55a
--- /dev/null
+++ b/compute_kernel_writer/src/types/ConstantData.cpp
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ckw/types/ConstantData.h"
+
+#include <limits>
+
+namespace ckw
+{
+namespace
+{
+ template<typename T>
+ inline typename std::enable_if<std::is_same<T, float>::value, std::string>::type to_str(T value)
+ {
+ std::stringstream ss;
+ ss << std::scientific << std::setprecision(std::numeric_limits<T>::max_digits10) << value;
+ return ss.str();
+ }
+
+ template<typename T>
+ inline typename std::enable_if<!std::is_same<T, float>::value && !std::is_same<T, bool>::value, std::string>::type to_str(T value)
+ {
+ return std::to_string(value);
+ }
+
+ template<typename T>
+ inline typename std::enable_if<std::is_same<T, bool>::value, std::string>::type to_str(T value)
+ {
+ return std::to_string((int) value);
+ }
+}
+
+template<typename T>
+ConstantData::ConstantData(std::initializer_list<std::initializer_list<T>> values, DataType data_type)
+ : _data_type(data_type)
+{
+ CKW_ASSERT(validate<T>(data_type));
+ CKW_ASSERT(values.size() > 0);
+
+ for(auto value_arr: values)
+ {
+ // Each row must have the same number of elements
+ CKW_ASSERT(value_arr.size() == (*values.begin()).size());
+
+ StringVector vec;
+ std::transform(value_arr.begin(), value_arr.end(),
+ std::back_inserter(vec),
+ [](T val) { return to_str(val); });
+
+ _values.push_back(std::move(vec));
+ }
+}
+
+template<typename T>
+bool ConstantData::validate(DataType data_type)
+{
+ switch(data_type)
+ {
+ case DataType::Fp32:
+ case DataType::Fp16:
+ return std::is_same<T, float>::value;
+ case DataType::Bool:
+ return std::is_same<T, bool>::value;
+ case DataType::Int32:
+ case DataType::Int16:
+ case DataType::Int8:
+ return std::is_same<T, int32_t>::value;
+ case DataType::Uint32:
+ case DataType::Uint16:
+ case DataType::Uint8:
+ return std::is_same<T, uint32_t>::value;
+ default:
+ CKW_THROW_MSG("Unknown data type!");
+ break;
+ }
+}
+
+// Necessary instantiations for compiler to recognize
+template ConstantData::ConstantData(std::initializer_list<std::initializer_list<int32_t>>, DataType);
+template ConstantData::ConstantData(std::initializer_list<std::initializer_list<uint32_t>>, DataType);
+template ConstantData::ConstantData(std::initializer_list<std::initializer_list<bool>>, DataType);
+template ConstantData::ConstantData(std::initializer_list<std::initializer_list<float>>, DataType);
+
+template bool ConstantData::validate<int32_t>(DataType);
+template bool ConstantData::validate<uint32_t>(DataType);
+template bool ConstantData::validate<bool>(DataType);
+template bool ConstantData::validate<float>(DataType);
+
+const std::vector<std::vector<std::string>>& ConstantData::values() const
+{
+ return _values;
+}
+
+DataType ConstantData::data_type() const
+{
+ return _data_type;
+}
+
+} // namespace ckw