diff options
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.h')
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.h | 55 |
1 files changed, 33 insertions, 22 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index ceb7a3f775..5be39a54c0 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -25,6 +25,7 @@ #define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H #include "arm_compute/function_info/ActivationLayerInfo.h" + #include "src/core/common/Macros.h" #include "src/cpu/ICpuOperator.h" @@ -42,20 +43,20 @@ enum class AsmConvMethod struct AsmGemmInfo { - AsmConvMethod method{ AsmConvMethod::Im2Col }; + AsmConvMethod method{AsmConvMethod::Im2Col}; PadStrideInfo ps_info{}; ActivationLayerInfo activation_info{}; GEMMLowpOutputStageInfo output_stage{}; - bool negated_offsets{ true }; - bool reinterpret_input_as_3d{ false }; - bool depth_output_gemm3d{ false }; - int64_t padding_top{ 0 }; - int64_t padding_left{ 0 }; - float padding_value{ 0.f }; - bool fast_mode{ false }; - bool fixed_format{ false }; - arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; - bool reshape_b_only_on_first_run{ true }; + bool negated_offsets{true}; + bool reinterpret_input_as_3d{false}; + bool depth_output_gemm3d{false}; + int64_t padding_top{0}; + int64_t padding_left{0}; + float padding_value{0.f}; + bool fast_mode{false}; + bool fixed_format{false}; + arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED}; + bool reshape_b_only_on_first_run{true}; }; /** Assembly kernel glue */ @@ -72,12 +73,12 @@ public: class IFallback { public: - virtual void run(ITensorPack &tensors) = 0; - virtual void prepare(ITensorPack &tensors) = 0; - virtual experimental::MemoryRequirements workspace() const = 0; - virtual bool is_configured() const = 0; - virtual bool isVarWeightsKernel() const = 0; - virtual ~IFallback() = default; + virtual void run(ITensorPack &tensors) = 0; + virtual void prepare(ITensorPack &tensors) = 0; + virtual experimental::MemoryRequirements workspace() const = 0; + virtual bool is_configured() const = 0; + virtual bool isVarWeightsKernel() const = 0; + virtual ~IFallback() = default; }; public: @@ -121,7 +122,8 @@ public: * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. * @param[in] info GEMM meta-data */ - void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info); + void configure( + const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info); /** Indicates whether or not this function can be used to process the given parameters. * @@ -133,7 +135,11 @@ public: * * @return a status. */ - static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); + static Status validate(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *d, + const AsmGemmInfo &info); /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. * @@ -144,7 +150,12 @@ public: * * @return a status. */ - static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); + static Status has_opt_impl(arm_compute::WeightFormat &weight_format, + const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *d, + const AsmGemmInfo &info); /** Checks if activation is supported by the gemm assembly dispatcher * * @param[in] activation Activation to check @@ -167,8 +178,8 @@ public: } // Inherited methods overridden: - void prepare(ITensorPack &tensors) override; - void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: |