diff options
Diffstat (limited to 'arm_compute/graph/Types.h')
-rw-r--r-- | arm_compute/graph/Types.h | 166 |
1 files changed, 109 insertions, 57 deletions
diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h index db5bbb8604..00d37a3354 100644 --- a/arm_compute/graph/Types.h +++ b/arm_compute/graph/Types.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,96 +24,148 @@ #ifndef __ARM_COMPUTE_GRAPH_TYPES_H__ #define __ARM_COMPUTE_GRAPH_TYPES_H__ -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/SubTensorInfo.h" -#include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/utils/logging/Macros.h" +#include "arm_compute/core/Error.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/strong_type/StrongType.h" +#include "arm_compute/core/utils/strong_type/StrongTypeAttributes.h" -/** Create a default core logger - * - * @note It will eventually create all default loggers in don't exist - */ -#define ARM_COMPUTE_CREATE_DEFAULT_GRAPH_LOGGER() \ - do \ - { \ - if(arm_compute::logging::LoggerRegistry::get().logger("GRAPH") == nullptr) \ - { \ - arm_compute::logging::LoggerRegistry::get().create_reserved_loggers(); \ - } \ - } while(false) - -#define ARM_COMPUTE_LOG_GRAPH(log_level, x) \ - ARM_COMPUTE_CREATE_DEFAULT_GRAPH_LOGGER(); \ - ARM_COMPUTE_LOG_STREAM("GRAPH", log_level, x) - -#define ARM_COMPUTE_LOG_GRAPH_INFO(x) \ - ARM_COMPUTE_CREATE_DEFAULT_GRAPH_LOGGER(); \ - ARM_COMPUTE_LOG_STREAM("GRAPH", arm_compute::logging::LogLevel::INFO, x) +#include <limits> +#include <string> namespace arm_compute { namespace graph { -using arm_compute::ActivationLayerInfo; +using arm_compute::Status; + using arm_compute::Coordinates; using arm_compute::DataType; -using arm_compute::DimensionRoundingType; -using arm_compute::ITensorInfo; +using arm_compute::TensorShape; +using arm_compute::Size2D; + +using arm_compute::ActivationLayerInfo; using arm_compute::NormType; using arm_compute::NormalizationLayerInfo; using arm_compute::PadStrideInfo; using arm_compute::PoolingLayerInfo; using arm_compute::PoolingType; -using arm_compute::SubTensorInfo; -using arm_compute::TensorInfo; -using arm_compute::TensorShape; -using arm_compute::WeightsInfo; +using arm_compute::DimensionRoundingType; + +/** TODO (geopin01): Make ids strongly typed */ +using TensorID = unsigned int; +using NodeID = unsigned int; +using EdgeID = unsigned int; +using Activation = arm_compute::ActivationLayerInfo::ActivationFunction; + +/**< GraphID strong type */ +using GraphID = strong_type::StrongType<unsigned int, struct graph_id_t, strong_type::Comparable>; +/* TODO (geopin01): Strong types for NodeID */ + +/**< Constant TensorID specifying an equivalent of null tensor */ +constexpr TensorID NullTensorID = std::numeric_limits<TensorID>::max(); +/**< Constant NodeID specifying an equivalent of null node */ +constexpr NodeID EmptyNodeID = std::numeric_limits<NodeID>::max(); +/**< Constant EdgeID specifying an equivalent of null edge */ +constexpr EdgeID EmptyEdgeID = std::numeric_limits<EdgeID>::max(); + +// Forward declarations +class TensorDescriptor; + +/** Graph configuration structure */ +struct GraphConfig +{ + bool use_function_memory_manager{ false }; /**< Use a memory manager to manage per-funcion auxilary memory */ + bool use_transition_memory_manager{ false }; /**< Use a memory manager to manager transition buffer memory */ + bool use_tuner{ false }; /**< Use a tuner in tunable backends */ + unsigned int num_threads{ 0 }; /**< Number of threads to use (thread capable backends), if 0 the backend will auto-initialize */ +}; -using arm_compute::logging::LogLevel; -using arm_compute::ConvertPolicy; +/**< Data layout format */ +enum class DataLayout +{ + NCHW, /** N(Batches), C(Channels), H(Height), W(Width) from slow to fast moving dimension */ + NHWC /** N(Batches), H(Height), W(Width), C(Channels) from slow to fast moving dimension */ +}; + +/**< Device target types */ +enum class Target +{ + UNSPECIFIED, /**< Unspecified Target */ + NEON, /**< NEON capable target device */ + CL, /**< OpenCL capable target device */ + GC, /**< GLES compute capable target device */ +}; -/**< Execution hint to the graph executor */ -enum class TargetHint +/** Supported Element-wise operations */ +enum class EltwiseOperation { - DONT_CARE, /**< Run node in any device */ - OPENCL, /**< Run node on an OpenCL capable device (GPU) */ - NEON /**< Run node on a NEON capable device */ + ADD, /**< Arithmetic addition */ + SUB, /**< Arithmetic subtraction */ + MUL /**< Arithmetic multiplication */ }; -/** Convolution method hint to the graph executor */ -enum class ConvolutionMethodHint +/** Supported Convolution layer methods */ +enum class ConvolutionMethod { - GEMM, /**< Convolution using GEMM */ - DIRECT, /**< Direct convolution */ - WINOGRAD /**< Winograd convolution */ + DEFAULT, /**< Default approach using internal heuristics */ + GEMM, /**< GEMM based convolution */ + DIRECT, /**< Deep direct convolution */ + WINOGRAD /**< Winograd based convolution */ }; -/** Supported layer operations */ -enum class OperationType +/** Supported Depthwise Convolution layer methods */ +enum class DepthwiseConvolutionMethod +{ + DEFAULT, /**< Default approach using internal heuristics */ + GEMV, /**< Generic GEMV based depthwise convolution */ + OPTIMIZED_3x3, /**< Optimized 3x3 direct depthwise convolution */ +}; + +/** Supported nodes */ +enum class NodeType { ActivationLayer, - ArithmeticAddition, BatchNormalizationLayer, ConvolutionLayer, - DepthConvertLayer, + DepthConcatenateLayer, DepthwiseConvolutionLayer, - DequantizationLayer, + EltwiseLayer, FlattenLayer, - FloorLayer, FullyConnectedLayer, - L2NormalizeLayer, NormalizationLayer, PoolingLayer, - QuantizationLayer, ReshapeLayer, - SoftmaxLayer + SoftmaxLayer, + SplitLayer, + + Input, + Output, + Const, +}; + +/** Backend Memory Manager affinity **/ +enum class MemoryManagerAffinity +{ + Buffer, /**< Affinity at buffer level */ + Offset /**< Affinity at offset level */ +}; + +/** NodeID-index struct + * + * Used to describe connections + */ +struct NodeIdxPair +{ + NodeID node_id; /**< Node ID */ + size_t index; /**< Index */ }; -/** Branch layer merging method */ -enum class BranchMergeMethod +/** Common node parameters */ +struct NodeParams { - DEPTH_CONCATENATE /**< Concatenate across depth */ + std::string name; /**< Node name */ + Target target; /**< Node target */ }; } // namespace graph } // namespace arm_compute -#endif /*__ARM_COMPUTE_GRAPH_TYPES_H__*/ +#endif /* __ARM_COMPUTE_GRAPH_TYPES_H__ */ |