diff options
Diffstat (limited to 'compute_kernel_writer/src/cl/CLTensorArgument.cpp')
-rw-r--r-- | compute_kernel_writer/src/cl/CLTensorArgument.cpp | 150 |
1 files changed, 57 insertions, 93 deletions
diff --git a/compute_kernel_writer/src/cl/CLTensorArgument.cpp b/compute_kernel_writer/src/cl/CLTensorArgument.cpp index ed1c5bd687..7d4dc958df 100644 --- a/compute_kernel_writer/src/cl/CLTensorArgument.cpp +++ b/compute_kernel_writer/src/cl/CLTensorArgument.cpp @@ -24,7 +24,10 @@ #include "src/cl/CLTensorArgument.h" #include "ckw/Error.h" +#include "src/ITensorArgument.h" +#include "src/ITensorComponent.h" #include "src/cl/CLHelpers.h" +#include "src/cl/CLTensorComponent.h" #include "src/types/TensorComponentType.h" #include <algorithm> @@ -39,8 +42,25 @@ CLTensorArgument::CLTensorArgument(const std::string &name, const TensorInfo &in _info = info; } -TileVariable CLTensorArgument::component(TensorComponentType x) +CLTensorArgument::~CLTensorArgument() = default; + +CLTensorComponent &CLTensorArgument::cl_component(TensorComponentType x) { + // Return the component if it has already been created. + { + const auto it = std::find_if( + _components_used.begin(), _components_used.end(), + [=](const std::unique_ptr<CLTensorComponent> &item) + { + return item->component_type() == x; + }); + + if(it != _components_used.end()) + { + return **it; + } + } + if(_return_dims_by_value) { uint32_t component_type = static_cast<uint32_t>(x); @@ -100,42 +120,47 @@ TileVariable CLTensorArgument::component(TensorComponentType x) if(idx != kDynamicTensorDimensionValue) { - TileVariable t; - t.str = std::to_string(idx); - t.desc.dt = DataType::Uint32; - t.desc.len = 1; - return t; + _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x, idx)); + + return *_components_used.back(); } } } - auto it = std::find(_components_used.begin(), _components_used.end(), x); + _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x)); - // Add to the list of used components if not present yet - if(it == _components_used.end()) - { - _components_used.push_back(x); - } + return *_components_used.back(); +} - TileVariable t; - t.str = create_component_name(x); - t.desc.dt = DataType::Int32; - t.desc.len = 1; - return t; +ITile &CLTensorArgument::component(TensorComponentType x) +{ + return cl_component(x); } -TensorStorageVariable CLTensorArgument::storage(TensorStorageType x) +TensorStorageVariable &CLTensorArgument::storage(TensorStorageType x) { - if(std::find(_storages_used.begin(), _storages_used.end(), x) == _storages_used.end()) + // Return the storage if it has already been created. { - _storages_used.push_back(x); + const auto it = std::find_if( + _storages_used.begin(), _storages_used.end(), + [=](const TensorStorageVariable &item) + { + return item.type == x; + }); + + if(it != _storages_used.end()) + { + return *it; + } } TensorStorageVariable t; t.val = create_storage_name(x); - t.type = cl_get_variable_storagetype_as_string(x); + t.type = x; + + _storages_used.emplace_back(t); - return t; + return _storages_used.back(); } std::string CLTensorArgument::create_storage_name(TensorStorageType x) const @@ -159,87 +184,26 @@ std::string CLTensorArgument::create_storage_name(TensorStorageType x) const return var_name; } -std::string CLTensorArgument::create_component_name(TensorComponentType x) const -{ - std::string var_name = _basename; - - switch(x) - { - case TensorComponentType::OffsetFirstElement: - var_name += "_offset_first_element"; - break; - case TensorComponentType::Stride0: - var_name += "_stride0"; - break; - case TensorComponentType::Stride1: - var_name += "_stride1"; - break; - case TensorComponentType::Stride2: - var_name += "_stride2"; - break; - case TensorComponentType::Stride3: - var_name += "_stride3"; - break; - case TensorComponentType::Stride4: - var_name += "_stride4"; - break; - case TensorComponentType::Dim0: - var_name += "_dim0"; - break; - case TensorComponentType::Dim1: - var_name += "_dim1"; - break; - case TensorComponentType::Dim2: - var_name += "_dim2"; - break; - case TensorComponentType::Dim3: - var_name += "_dim3"; - break; - case TensorComponentType::Dim4: - var_name += "_dim4"; - break; - case TensorComponentType::Dim1xDim2: - var_name += "_dim1xdim2"; - break; - case TensorComponentType::Dim2xDim3: - var_name += "_dim2xdim3"; - break; - case TensorComponentType::Dim1xDim2xDim3: - var_name += "_dim1xdim2xdim3"; - break; - default: - COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor component"); - return ""; - } - - return var_name; -} - std::vector<TensorStorageVariable> CLTensorArgument::storages() const { std::vector<TensorStorageVariable> storages; - for(auto &val : _storages_used) - { - TensorStorageVariable t; - t.val = create_storage_name(val); - t.type = cl_get_variable_storagetype_as_string(val); - storages.push_back(t); - } + storages.reserve(_storages_used.size()); + + std::copy(_storages_used.begin(), _storages_used.end(), std::back_inserter(storages)); return storages; } -std::vector<TileVariable> CLTensorArgument::components() const +std::vector<const ITensorComponent *> CLTensorArgument::components() const { - std::vector<TileVariable> components; + std::vector<const ITensorComponent *> components; - for(auto &val : _components_used) + for(const auto &component : _components_used) { - TileVariable t; - t.str = create_component_name(val); - t.desc.dt = DataType::Int32; - t.desc.len = 1; - components.push_back(t); + if(component->is_assignable()) + { + components.push_back(component.get()); + } } return components; |