aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h')
-rw-r--r--src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h30
1 files changed, 18 insertions, 12 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h b/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h
index a857932791..c169476a70 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h
@@ -36,7 +36,6 @@ namespace experimental
{
namespace dynamic_fusion
{
-
/** Internal implementation of workload context. */
class GpuWorkloadContext::Impl
{
@@ -52,7 +51,7 @@ public:
Impl(Impl &) = default;
/** Assignment */
- Impl& operator=(Impl &) = default;
+ Impl &operator=(Impl &) = default;
/** Get target GPU language. */
GpuLanguage gpu_language() const;
@@ -69,27 +68,34 @@ public:
*/
void register_user_tensor(ITensorInfo &tensor_info);
- /** Set a new ID and register the auxiliary tensor info.
+ /** Create a virtual (see @ref MemoryType) tensor info and save it
*
- * @param[in, out] tensor_info The tensor info to be registered.
- * @param[in] mem_info The auxiliary tensor memory info.
+ * @return ITensorInfo* The created virtual tensor info object pointer
*/
- void register_aux_tensor(ITensorInfo &tensor_info, const AuxMemoryInfo &mem_info);
-
- /** Set a new ID and register the virtual tensor info.
+ ITensorInfo *create_virtual_tensor();
+ /** Create an auxiliary (see @ref MemoryType) tensor info and save it
*
- * @param[in, out] tensor_info The tensor info to be registered.
+ * @param[in] tensor_info @ref ITensorInfo to copy from
+ *
+ * @return ITensorInfo* The created auxiliary tensor info object pointer
*/
- void register_virtual_tensor(ITensorInfo &tensor_info);
+ ITensorInfo *create_auxiliary_tensor(const ITensorInfo &tensor_info);
+
+ /** Get tensor info created by this context, from id */
+ ITensorInfo *get_tensor_info(ITensorInfo::Id id);
+
+ /** Get tensor info created by this context, from id */
+ const ITensorInfo *get_tensor_info(ITensorInfo::Id id) const;
private:
ITensorInfo::Id next_tensor_id();
- GpuLanguage _gpu_language;
+ GpuLanguage _gpu_language;
CLCompileContext *_cl_compile_ctx;
- ITensorInfo::Id _next_tensor_id;
+ ITensorInfo::Id _next_tensor_id;
MemoryDescriptorMap _mem_map;
+ std::map<ITensorInfo::Id, std::unique_ptr<TensorInfo>> _managed_tensor_info;
};
} // namespace dynamic_fusion