/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL

#include "arm_compute/core/experimental/ClWorkload.h"
#include "src/core/experimental/dynamic_fusion/WorkloadImpl/ITensorDescPack.h"

#include "support/Cast.h"
#include "support/DeepCopy.h"

// NOTE(review): the three standard header names were missing from the text; these are the
// headers for the std names this file uses (std::map, std::forward, std::vector) - confirm.
#include <map>
#include <utility>
#include <vector>

namespace arm_compute
{
namespace experimental
{
namespace dynamic_fusion
{
/** Complexity class reported by an operator through @ref OperatorContent::complexity() */
enum class OperatorComplexity
{
    Complex = 0,
    Simple
};

struct ClKernelGraph;

/** Tensor entry of an operator graph: an id paired with a pointer to its tensor info.
 *
 * The destructor is defaulted, so the pointed-to @ref ITensorInfo is not released by this struct.
 */
struct OpTensorContent
{
public:
    using Id = DependencyGraph::Id;

    OpTensorContent() = default;
    /** Construct from an id only; @ref desc is left null
     *
     * @param[in] id Id of the tensor
     */
    OpTensorContent(Id id)
        : id{ id }, desc{}
    {
    }
    /** Construct from an id and a tensor info pointer
     *
     * @param[in] id   Id of the tensor
     * @param[in] desc Tensor info pointer stored as-is
     */
    OpTensorContent(Id id, ITensorInfo *desc)
        : id{ id }, desc{ desc }
    {
    }
    ~OpTensorContent()                       = default;
    OpTensorContent(const OpTensorContent &) = default;
    OpTensorContent &operator=(const OpTensorContent &) = default;
    OpTensorContent(OpTensorContent &&)                 = default;
    OpTensorContent &operator=(OpTensorContent &&) = default;

    /** Equality compares the stored tensor info pointers only; @ref id does not take part
     *
     * @param[in] other Tensor to compare against
     *
     * @return true if both objects hold the same @ref desc pointer
     */
    bool operator==(const OpTensorContent &other) const
    {
        return desc == other.desc;
    }

    /** @return The stored tensor info pointer (may be null) */
    const ITensorInfo *get_tensor_info() const
    {
        return desc;
    }
    /** @return The stored tensor info pointer (may be null) */
    ITensorInfo *get_tensor_info()
    {
        return desc;
    }

    Id           id{};
    ITensorInfo *desc{};
};

/** Abstract base of every operator stored in an operator graph.
 *
 * Holds the graph the operator was created for, the operator id and its tensor pack.
 * Derived types supply the complexity class, equality and the translation into a @ref ClKernelGraph.
 */
struct OperatorContent
{
public:
    using Id = DependencyGraph::Id;

    OperatorContent() = default;
    // NOTE(review): the template argument of ITensorDescPack was missing from the text;
    // OpTensorContent is assumed because it is the tensor type of this graph - confirm.
    /** Constructor
     *
     * @param[in] graph   Operator graph implementation this operator is created for
     * @param[in] id      Id of the operator
     * @param[in] tensors Tensor pack of the operator (copied)
     */
    OperatorContent(const OperatorGraph::Implementation *graph, Id id, const ITensorDescPack<OpTensorContent> &tensors)
        : _graph{ graph }, _id{ id }, _tensors{ tensors }
    {
    }
    OperatorContent(const OperatorContent &op) = default;
    OperatorContent &operator=(const OperatorContent &op) = default;
    OperatorContent(OperatorContent &&op)                 = default;
    OperatorContent &operator=(OperatorContent &&op) = default;
    virtual ~OperatorContent()                       = default;

    /** @return Complexity class of the operator */
    virtual OperatorComplexity complexity() const = 0;
    /** Compare against another operator
     *
     * @param[in] other Operator to compare against
     *
     * @return Result of the comparison as defined by the derived type
     */
    virtual bool operator==(const OperatorContent &other) const = 0;
    /** Translate this operator into the given kernel graph
     *
     * @param[in, out] kernel_graph Kernel graph handed to the translation
     *
     * @return Status of the translation
     */
    virtual Status translate(ClKernelGraph &kernel_graph) const = 0;

protected:
    const OperatorGraph::Implementation *_graph{};
    Id                                   _id{};
    ITensorDescPack<OpTensorContent>     _tensors{};
};

/** Conv2d operator. Reports @ref OperatorComplexity::Complex and can carry a forced convolution method. */
struct Conv2dContent : public OperatorContent
{
public:
    Conv2dContent() = default;
    /** Construct without forcing a convolution method
     *
     * @note @ref forced_method is value-initialised here; its in-class default is not used.
     *
     * @param[in] graph   Operator graph implementation this operator is created for
     * @param[in] id      Id of the operator
     * @param[in] desc    Conv2d descriptor
     * @param[in] tensors Tensor pack of the operator
     */
    Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
        : OperatorContent(graph, id, tensors), desc(desc), forced_method(), forced_method_enabled(false)
    {
    }
    // Temporary. Do not need to pass ConvolutionMethod
    /** Construct with a forced convolution method
     *
     * @param[in] graph   Operator graph implementation this operator is created for
     * @param[in] id      Id of the operator
     * @param[in] desc    Conv2d descriptor
     * @param[in] tensors Tensor pack of the operator
     * @param[in] method  Convolution method to force
     */
    Conv2dContent(const OperatorGraph::Implementation *graph, Id id, const Conv2dDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors, ConvolutionMethod method)
        : OperatorContent(graph, id, tensors), desc(desc), forced_method(method), forced_method_enabled(true)
    {
    }
    ~Conv2dContent()                     = default;
    Conv2dContent(const Conv2dContent &) = default;
    Conv2dContent &operator=(const Conv2dContent &) = default;
    Conv2dContent(Conv2dContent &&)                 = default;
    Conv2dContent &operator=(Conv2dContent &&) = default;

    bool operator==(const OperatorContent &other) const override;
    OperatorComplexity complexity() const override
    {
        return OperatorComplexity::Complex;
    }
    /** Force a convolution method; also sets @ref forced_method_enabled
     *
     * @param[in] method Convolution method to force
     */
    void set_method(ConvolutionMethod method)
    {
        forced_method_enabled = true;
        forced_method         = method;
    }

    Status translate(ClKernelGraph &kernel_graph) const override;
    /** Replicate heuristics of @ref ClConv2d::get_convolution_method(), except that non-supported data types and data layouts are removed from the heuristics
     *
     * @param[in] src         Source tensor info
     * @param[in] weights     Weights tensor info
     * @param[in] dst         Destination tensor info
     * @param[in] conv2d_desc Conv2d descriptor
     * @param[in] gpu_target  GPU target
     *
     * @return ConvolutionMethod
     */
    static ConvolutionMethod select_conv_method(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const Conv2dDescriptor &conv2d_desc, const GPUTarget gpu_target);

    Conv2dDescriptor  desc{};
    ConvolutionMethod forced_method{ ConvolutionMethod::GEMM_CONV2D };
    bool              forced_method_enabled{ false };

private:
    Status translate_direct_conv2d(ClKernelGraph &kernel_graph) const;
};

/** Elementwise operator. Reports @ref OperatorComplexity::Simple. */
class ElementwiseContent : public OperatorContent
{
public:
    ElementwiseContent() = default;
    /** Constructor
     *
     * @param[in] graph   Operator graph implementation this operator is created for
     * @param[in] id      Id of the operator
     * @param[in] desc    Elementwise descriptor
     * @param[in] tensors Tensor pack of the operator
     */
    ElementwiseContent(const OperatorGraph::Implementation *graph, Id id, const ElementwiseDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
        : OperatorContent(graph, id, tensors), desc(desc)
    {
    }
    ~ElementwiseContent()                          = default;
    ElementwiseContent(const ElementwiseContent &) = default;
    ElementwiseContent &operator=(const ElementwiseContent &) = default;
    ElementwiseContent(ElementwiseContent &&)                 = default;
    ElementwiseContent &operator=(ElementwiseContent &&) = default;

    bool operator==(const OperatorContent &other) const override;
    OperatorComplexity complexity() const override
    {
        return OperatorComplexity::Simple;
    }
    Status translate(ClKernelGraph &kernel_graph) const override;

private:
    ElementwiseDescriptor desc{};
};

/** Floor operator. Reports @ref OperatorComplexity::Simple. */
class FloorContent : public OperatorContent
{
public:
    FloorContent() = default;
    /** Constructor
     *
     * @param[in] graph   Operator graph implementation this operator is created for
     * @param[in] id      Id of the operator
     * @param[in] desc    Floor descriptor
     * @param[in] tensors Tensor pack of the operator
     */
    FloorContent(const OperatorGraph::Implementation *graph, Id id, const FloorDescriptor &desc, const ITensorDescPack<OpTensorContent> &tensors)
        : OperatorContent(graph, id, tensors), desc(desc)
    {
    }
    ~FloorContent()                    = default;
    FloorContent(const FloorContent &) = default;
    FloorContent &operator=(const FloorContent &) = default;
    FloorContent(FloorContent &&)                 = default;
    FloorContent &operator=(FloorContent &&) = default;

    bool operator==(const OperatorContent &other) const override;
    OperatorComplexity complexity() const override
    {
        return OperatorComplexity::Simple;
    }
    Status translate(ClKernelGraph &kernel_graph) const override;

private:
    FloorDescriptor desc{};
};

/** Storage behind @ref OperatorGraph: the dependency graph plus id-keyed maps of operators and tensors.
 *
 * NOTE(review): every template parameter/argument list in this struct was missing from the text and
 * has been re-inserted from the use sites; this assumes support/DeepCopy.h provides
 * make_deep_unique<Base, Derived>(args...) returning deep_unique_ptr<Base> - confirm.
 */
struct OperatorGraph::Implementation
{
public:
    /** Create an operator of type @p ContentT and store it under @p id, replacing any previous entry
     *
     * The new operator is constructed with this graph, @p id and the forwarded @p args.
     *
     * @tparam ContentT Concrete operator type to create
     *
     * @param[in] id   Id of the operator
     * @param[in] args Remaining constructor arguments of @p ContentT
     */
    template <typename ContentT, typename... Args>
    void add_node(Operator::Id id, Args &&... args)
    {
        operators[id] = utils::memory::make_deep_unique<OperatorContent, ContentT>(this, id, std::forward<Args>(args)...);
    }

    /** Create a tensor entry and store it under @p id, replacing any previous entry
     *
     * @param[in] id   Id of the tensor
     * @param[in] args Remaining constructor arguments of @ref OpTensorContent
     */
    template <typename... Args>
    void add_tensor(OpTensor::Id id, Args &&... args)
    {
        tensors[id] = utils::memory::make_deep_unique<OpTensorContent, OpTensorContent>(id, std::forward<Args>(args)...);
    }

    using Dependency  = DependencyGraph;
    using OperatorMap = std::map<Operator::Id, utils::memory::deep_unique_ptr<OperatorContent>>;
    using OpTensorMap = std::map<OpTensor::Id, utils::memory::deep_unique_ptr<OpTensorContent>>;

    Implementation()  = default;
    ~Implementation() = default;

    /** Two graphs are equal when their dependency graphs, operator maps and tensor maps all compare equal.
     *
     * @ref status does not take part in the comparison.
     */
    friend bool operator==(const OperatorGraph::Implementation &graph0, const OperatorGraph::Implementation &graph1)
    {
        return graph0.graph == graph1.graph && graph0.operators == graph1.operators && graph0.tensors == graph1.tensors;
    }

    Dependency  graph{};
    OperatorMap operators{};
    OpTensorMap tensors{};
    Status      status{};
};

// NOTE(review): the element types of the two returned vectors were missing from the text; pointers to
// OperatorContent with constness following the graph argument are assumed - confirm against the definitions.
/** Collect the operators of a constant graph
 *
 * @param[in] graph Operator graph to traverse
 *
 * @return Pointers to the operators of @p graph
 */
std::vector<const OperatorContent *> traverse(const OperatorGraph::Implementation &graph);

/** Collect the operators of a mutable graph
 *
 * @param[in] graph Operator graph to traverse
 *
 * @return Pointers to the operators of @p graph
 */
std::vector<OperatorContent *> traverse(OperatorGraph::Implementation &graph);

/** Translate an operator graph into a kernel graph
 *
 * @param[in, out] kernel_graph Kernel graph handed to the translation
 * @param[in]      op_graph     Operator graph to translate
 *
 * @return Status of the translation
 */
Status translate(ClKernelGraph &kernel_graph, const OperatorGraph::Implementation &op_graph);

} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute

#endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_OPERATORGRAPHIMPL
#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */