aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuReshapeKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuReshapeKernel.h')
-rw-r--r--src/cpu/kernels/CpuReshapeKernel.h24
1 files changed, 23 insertions, 1 deletions
diff --git a/src/cpu/kernels/CpuReshapeKernel.h b/src/cpu/kernels/CpuReshapeKernel.h
index 17302c6731..eddbbf7135 100644
--- a/src/cpu/kernels/CpuReshapeKernel.h
+++ b/src/cpu/kernels/CpuReshapeKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,6 +58,13 @@ public:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+ /** Prepare the reshape kernel for execution (Only executed once) by calculating max or squashed window and selecting the _reshape_tensor_fn based on the presence of holes
+ *
+ * @param[in] tensors Pack of input and output tensors
+ *
+ */
+ void prepare(ITensorPack &tensors);
+
/** Return minimum workload size of the relevant kernel
*
* @param[in] platform The CPU platform used to create the context.
@@ -66,6 +73,21 @@ public:
* @return[out] small_network_mws Minimum workload size for requsted configuration.
*/
size_t get_mws(const CPUInfo &platform, size_t thread_count) const override;
+
+ /** Get the preferred dimension in which the scheduler splits the work into multiple jobs.
+ *
+ * @return The split dimension.
+ */
+ size_t get_split_dimension() const
+ {
+ return _split_dimension;
+ }
+
+private:
+ size_t _split_dimension{ Window::DimY };
+
+ std::function<void(const Window &window, const ITensor *src, ITensor *dst )> _reshape_tensor_fn{};
+
};
} // namespace kernels
} // namespace cpu