aboutsummaryrefslogtreecommitdiff
path: root/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h')
-rw-r--r--src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h229
1 files changed, 229 insertions, 0 deletions
diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h
new file mode 100644
index 0000000000..c33e189797
--- /dev/null
+++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h
@@ -0,0 +1,229 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
+#error "This experimental feature must be enabled with -DENABLE_EXPERIMENTAL_DYNAMIC_FUSION"
+#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */
+#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
+#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
+
+#include "arm_compute/core/experimental/ClWorkload.h"
+#include "src/core/experimental/dynamic_fusion/WorkloadImpl/ITensorDescPack.h"
+
+#include "support/Cast.h"
+#include "support/DeepCopy.h"
+
+#include <map>
+#include <tuple>
+#include <type_traits>
+
+namespace arm_compute
+{
+namespace experimental
+{
+namespace dynamic_fusion
+{
+enum class OperatorComplexity
+{
+ Complex = 0,
+ Simple
+};
+
+struct ClKernelGraph;
+struct OpTensorContent
+{
+public:
+ using Id = DependencyGraph::Id;
+ OpTensorContent() = default;
+ OpTensorContent(Id id)
+ : id{ id }, desc{}
+ {
+ }
+ OpTensorContent(Id id, ITensorInfo *desc)
+ : id{ id }, desc{ desc }
+ {
+ }
+ ~OpTensorContent() = default;
+ OpTensorContent(const OpTensorContent &) = default;
+ OpTensorContent &operator=(const OpTensorContent &) = default;
+ OpTensorContent(OpTensorContent &&) = default;
+ OpTensorContent &operator=(OpTensorContent &&) = default;
+ bool operator==(const OpTensorContent &other) const
+ {
+ return desc == other.desc;
+ }
+
+ const ITensorInfo *get_tensor_info() const
+ {
+ return desc;
+ }
+ ITensorInfo *get_tensor_info()
+ {
+ return desc;
+ }
+
+ Id id{};
+ ITensorInfo *desc{};
+};
+
+struct OperatorContent
+{
+public:
+ using Id = DependencyGraph::Id;
+ OperatorContent() = default;
+ OperatorContent(const OperatorGraph::Implementation *graph, Id id, const ITensorDescPack<OpTensorContent> &tensors)
+ : _graph{ graph }, _id{ id }, _tensors{ tensors }
+ {
+ }
+ OperatorContent(const OperatorContent &op) = default;
+ OperatorContent &operator=(const OperatorContent &op) = default;
+ OperatorContent(OperatorContent &&op) = default;
+ OperatorContent &operator=(OperatorContent &&op) = default;
+ virtual ~OperatorContent() = default;
+ virtual OperatorComplexity complexity() const = 0;
+ virtual bool operator==(const OperatorContent &other) const = 0;
+ virtual Status translate(ClKernelGraph &kernel_graph) const = 0;
+
+protected:
+ const OperatorGraph::Implementation *_graph {};
+ Id _id{};
+ ITensorDescPack<OpTensorContent> _tensors{};
+};
+
+struct Conv2dContent : public OperatorContent
+{
+public:
+ Conv2dContent() = default;
+ Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
+ : OperatorContent(graph, id, tensors), desc(desc), forced_method(), forced_method_enabled(false)
+ {
+ }
+ // Temporary. Do not need to pass ConvolutionMethod
+ Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors, ConvolutionMethod method)
+ : OperatorContent(graph, id, tensors), desc(desc), forced_method(method), forced_method_enabled(true)
+ {
+ }
+ ~Conv2dContent() = default;
+ Conv2dContent(const Conv2dContent &) = default;
+ Conv2dContent &operator=(const Conv2dContent &) = default;
+ Conv2dContent(Conv2dContent &&) = default;
+ Conv2dContent &operator=(Conv2dContent &&) = default;
+ bool operator==(const OperatorContent &other) const override;
+ OperatorComplexity complexity() const override
+ {
+ return OperatorComplexity::Complex;
+ }
+ void set_method(ConvolutionMethod method)
+ {
+ forced_method_enabled = true;
+ forced_method = method;
+ }
+
+ Status translate(ClKernelGraph &kernel_graph) const override;
+ /** Replicate heuristics of @ref ClConv2d::get_convolution_method(), except that non-supported data types and data layouts are removed from the heuristics
+ *
+ * @param src
+ * @param weights
+ * @param dst
+ * @param conv2d_desc
+ * @param gpu_target
+ * @return ConvolutionMethod
+ */
+ static ConvolutionMethod select_conv_method(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const Conv2dDescriptor &conv2d_desc, const GPUTarget gpu_target);
+
+ Conv2dDescriptor desc{};
+ ConvolutionMethod forced_method{ ConvolutionMethod::GEMM_CONV2D };
+ bool forced_method_enabled{ false };
+
+private:
+ Status translate_direct_conv2d(ClKernelGraph &kernel_graph) const;
+};
+
+class AddContent : public OperatorContent
+{
+public:
+ AddContent() = default;
+ AddContent(const OperatorGraph::Implementation *graph, Id id, const AddDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
+ : OperatorContent(graph, id, tensors), desc(desc)
+ {
+ }
+ ~AddContent() = default;
+ AddContent(const AddContent &) = default;
+ AddContent &operator=(const AddContent &) = default;
+ AddContent(AddContent &&) = default;
+ AddContent &operator=(AddContent &&) = default;
+ bool operator==(const OperatorContent &other) const override;
+ OperatorComplexity complexity() const override
+ {
+ return OperatorComplexity::Simple;
+ }
+ Status translate(ClKernelGraph &kernel_graph) const override;
+
+private:
+ AddDescriptor desc{};
+};
+
+struct OperatorGraph::Implementation
+{
+public:
+ template <typename ContentT, typename... Args>
+ void add_node(Operator::Id id, Args &&... args)
+ {
+ operators[id] = utils::memory::make_deep_unique<OperatorContent, ContentT>(this, id, std::forward<Args>(args)...);
+ }
+
+ template <typename... Args>
+ void add_tensor(OpTensor::Id id, Args &&... args)
+ {
+ tensors[id] = utils::memory::make_deep_unique<OpTensorContent, OpTensorContent>(id, std::forward<Args>(args)...);
+ }
+
+ using Dependency = DependencyGraph;
+ using OperatorMap = std::map<Operator::Id, utils::memory::deep_unique_ptr<OperatorContent>>;
+ using OpTensorMap = std::map<OpTensor::Id, utils::memory::deep_unique_ptr<OpTensorContent>>;
+
+ Implementation() = default;
+ ~Implementation() = default;
+
+ friend bool operator==(const OperatorGraph::Implementation &graph0, const OperatorGraph::Implementation &graph1)
+ {
+ return graph0.graph == graph1.graph && graph0.operators == graph1.operators && graph0.tensors == graph1.tensors;
+ }
+
+ Dependency graph{};
+ OperatorMap operators{};
+ OpTensorMap tensors{};
+ Status status{};
+};
+
+std::vector<const OperatorContent *> traverse(const OperatorGraph::Implementation &graph);
+
+std::vector<OperatorContent *> traverse(OperatorGraph::Implementation &graph);
+
+Status translate(ClKernelGraph &kernel_graph, const OperatorGraph::Implementation &op_graph);
+
+} // namespace dynamic_fusion
+} // namespace experimental
+} // namespace arm_compute
+
+#endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL \ No newline at end of file