diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-11-02 01:37:17 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-11-12 15:59:25 +0000 |
commit | c0b6f76561580414f08633a804fc548ccad65659 (patch) | |
tree | 4d46b7f479de04f799e29095392948aeb370c029 /arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h | |
parent | 824061d9910ebb42cbe46b677c0b843db212c9a2 (diff) | |
download | ComputeLibrary-c0b6f76561580414f08633a804fc548ccad65659.tar.gz |
COMPMID-3776: Indirect GEMM
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h | 58 |
1 files changed, 40 insertions, 18 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h index ac77acf69d..8f9498d0f5 100644 --- a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h +++ b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h @@ -32,6 +32,28 @@ namespace arm_compute { +/* Convolution method supported by the assembly gemm interface */ +enum class AsmConvMethod +{ + Im2Col, + Indirect, + Conv +}; + +struct AsmGemmInfo +{ + AsmConvMethod method{ AsmConvMethod::Im2Col }; + PadStrideInfo ps_info{}; + ActivationLayerInfo activation_info{}; + GEMMLowpOutputStageInfo output_stage{}; + bool negated_offsets{ true }; + bool reinterpret_input_as_3d{ false }; + bool depth_output_gemm3d{ false }; + int64_t padding_top{ 0 }; + int64_t padding_left{ 0 }; + float padding_value{ 0.f }; +}; + /** Assembly kernel glue */ class NEGEMMAssemblyDispatch : public IFunction { @@ -55,33 +77,28 @@ public: virtual ~IFallback() = default; }; -private: - /** Interface for the arm_gemm fallback */ - std::unique_ptr<IFallback> _arm_gemm; - MemoryGroup _memory_group; /**< Function memory group */ - IWeightsManager *_weights_manager; /**< Pointer to the weights manager */ public: /** If supported create a Compute Library function else fallback to the arm_gemm function. * - * @param[in] a Input tensor (Matrix A) - * @param[in] b Input tensor (Matrix B) - * @param[in] c Input tensor (Matrix C) used to pass the bias for quantized calculations - * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] gemm_info GEMM meta-data + * @param[in] a Input tensor (Matrix A) + * @param[in] b Input tensor (Matrix B) + * @param[in] c Input tensor (Matrix C) used to pass the bias for quantized calculations + * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] info GEMM meta-data */ - void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info); + void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info); /** Indicates whether or not this function can be used to process the given parameters. * - * @param[in] a Input tensor info (Matrix A) - * @param[in] b Input tensor info (Matrix B) - * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations - * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] gemm_info GEMM meta-data + * @param[in] a Input tensor info (Matrix A) + * @param[in] b Input tensor info (Matrix B) + * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations + * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] info GEMM meta-data * * @return a status. */ - static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info); + static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); /** Checks if activation is supported by the gemm assembly dispatcher * * @param[in] activation Activation to check @@ -94,10 +111,15 @@ public: * @return True if the function is configured and ready to run */ bool is_configured() const; + // Inherited methods overridden: - /** Runs a preparation step, usually for pre-transposing matrix b */ void prepare() override; void run() override; + +private: + std::unique_ptr<IFallback> _arm_gemm; /** Interface for the arm_gemm fallback */ + MemoryGroup _memory_group; /**< Function memory group */ + IWeightsManager *_weights_manager; /**< Pointer to the weights manager */ }; } // namespace arm_compute #endif /* ARM_COMPUTE_NEGEMMASSEMBLYDISPATCH_H */ |