aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/cl/CLTensorArgument.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/cl/CLTensorArgument.cpp')
-rw-r--r--compute_kernel_writer/src/cl/CLTensorArgument.cpp150
1 files changed, 57 insertions, 93 deletions
diff --git a/compute_kernel_writer/src/cl/CLTensorArgument.cpp b/compute_kernel_writer/src/cl/CLTensorArgument.cpp
index ed1c5bd687..7d4dc958df 100644
--- a/compute_kernel_writer/src/cl/CLTensorArgument.cpp
+++ b/compute_kernel_writer/src/cl/CLTensorArgument.cpp
@@ -24,7 +24,10 @@
#include "src/cl/CLTensorArgument.h"
#include "ckw/Error.h"
+#include "src/ITensorArgument.h"
+#include "src/ITensorComponent.h"
#include "src/cl/CLHelpers.h"
+#include "src/cl/CLTensorComponent.h"
#include "src/types/TensorComponentType.h"
#include <algorithm>
@@ -39,8 +42,25 @@ CLTensorArgument::CLTensorArgument(const std::string &name, const TensorInfo &in
_info = info;
}
-TileVariable CLTensorArgument::component(TensorComponentType x)
+CLTensorArgument::~CLTensorArgument() = default;
+
+CLTensorComponent &CLTensorArgument::cl_component(TensorComponentType x)
{
+ // Return the component if it has already been created.
+ {
+ const auto it = std::find_if(
+ _components_used.begin(), _components_used.end(),
+ [=](const std::unique_ptr<CLTensorComponent> &item)
+ {
+ return item->component_type() == x;
+ });
+
+ if(it != _components_used.end())
+ {
+ return **it;
+ }
+ }
+
if(_return_dims_by_value)
{
uint32_t component_type = static_cast<uint32_t>(x);
@@ -100,42 +120,47 @@ TileVariable CLTensorArgument::component(TensorComponentType x)
if(idx != kDynamicTensorDimensionValue)
{
- TileVariable t;
- t.str = std::to_string(idx);
- t.desc.dt = DataType::Uint32;
- t.desc.len = 1;
- return t;
+ _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x, idx));
+
+ return *_components_used.back();
}
}
}
- auto it = std::find(_components_used.begin(), _components_used.end(), x);
+ _components_used.emplace_back(std::make_unique<CLTensorComponent>(*this, x));
- // Add to the list of used components if not present yet
- if(it == _components_used.end())
- {
- _components_used.push_back(x);
- }
+ return *_components_used.back();
+}
- TileVariable t;
- t.str = create_component_name(x);
- t.desc.dt = DataType::Int32;
- t.desc.len = 1;
- return t;
+ITile &CLTensorArgument::component(TensorComponentType x)
+{
+ return cl_component(x);
}
-TensorStorageVariable CLTensorArgument::storage(TensorStorageType x)
+TensorStorageVariable &CLTensorArgument::storage(TensorStorageType x)
{
- if(std::find(_storages_used.begin(), _storages_used.end(), x) == _storages_used.end())
+ // Return the storage if it has already been created.
{
- _storages_used.push_back(x);
+ const auto it = std::find_if(
+ _storages_used.begin(), _storages_used.end(),
+ [=](const TensorStorageVariable &item)
+ {
+ return item.type == x;
+ });
+
+ if(it != _storages_used.end())
+ {
+ return *it;
+ }
}
TensorStorageVariable t;
t.val = create_storage_name(x);
- t.type = cl_get_variable_storagetype_as_string(x);
+ t.type = x;
+
+ _storages_used.emplace_back(t);
- return t;
+ return _storages_used.back();
}
std::string CLTensorArgument::create_storage_name(TensorStorageType x) const
@@ -159,87 +184,26 @@ std::string CLTensorArgument::create_storage_name(TensorStorageType x) const
return var_name;
}
-std::string CLTensorArgument::create_component_name(TensorComponentType x) const
-{
- std::string var_name = _basename;
-
- switch(x)
- {
- case TensorComponentType::OffsetFirstElement:
- var_name += "_offset_first_element";
- break;
- case TensorComponentType::Stride0:
- var_name += "_stride0";
- break;
- case TensorComponentType::Stride1:
- var_name += "_stride1";
- break;
- case TensorComponentType::Stride2:
- var_name += "_stride2";
- break;
- case TensorComponentType::Stride3:
- var_name += "_stride3";
- break;
- case TensorComponentType::Stride4:
- var_name += "_stride4";
- break;
- case TensorComponentType::Dim0:
- var_name += "_dim0";
- break;
- case TensorComponentType::Dim1:
- var_name += "_dim1";
- break;
- case TensorComponentType::Dim2:
- var_name += "_dim2";
- break;
- case TensorComponentType::Dim3:
- var_name += "_dim3";
- break;
- case TensorComponentType::Dim4:
- var_name += "_dim4";
- break;
- case TensorComponentType::Dim1xDim2:
- var_name += "_dim1xdim2";
- break;
- case TensorComponentType::Dim2xDim3:
- var_name += "_dim2xdim3";
- break;
- case TensorComponentType::Dim1xDim2xDim3:
- var_name += "_dim1xdim2xdim3";
- break;
- default:
- COMPUTE_KERNEL_WRITER_ERROR_ON_MSG("Unsupported tensor component");
- return "";
- }
-
- return var_name;
-}
-
std::vector<TensorStorageVariable> CLTensorArgument::storages() const
{
std::vector<TensorStorageVariable> storages;
- for(auto &val : _storages_used)
- {
- TensorStorageVariable t;
- t.val = create_storage_name(val);
- t.type = cl_get_variable_storagetype_as_string(val);
- storages.push_back(t);
- }
+ storages.reserve(_storages_used.size());
+
+ std::copy(_storages_used.begin(), _storages_used.end(), std::back_inserter(storages));
return storages;
}
-std::vector<TileVariable> CLTensorArgument::components() const
+std::vector<const ITensorComponent *> CLTensorArgument::components() const
{
- std::vector<TileVariable> components;
+ std::vector<const ITensorComponent *> components;
- for(auto &val : _components_used)
+ for(const auto &component : _components_used)
{
- TileVariable t;
- t.str = create_component_name(val);
- t.desc.dt = DataType::Int32;
- t.desc.len = 1;
- components.push_back(t);
+ if(component->is_assignable())
+ {
+ components.push_back(component.get());
+ }
}
return components;