diff options
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.h')
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.h | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 3c25866f25..4ef108d430 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -53,6 +53,7 @@ struct AsmGemmInfo float padding_value{ 0.f }; bool fast_mode{ false }; bool fixed_format{ false }; + arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED }; }; /** Assembly kernel glue */ @@ -73,6 +74,7 @@ public: virtual void prepare(ITensorPack &tensors) = 0; virtual experimental::MemoryRequirements workspace() const = 0; virtual bool is_configured() const = 0; + virtual bool isVarWeightsKernel() const = 0; virtual ~IFallback() = default; }; @@ -101,15 +103,14 @@ public: /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. * - * @param[in] a Input tensor info (Matrix A) - * @param[in] b Input tensor info (Matrix B) - * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations - * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] info GEMM meta-data + * This method has the same use of @ref + * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that + * the value of arm_gemm::WeightFormat need to be passed via the + * parameter info. * * @return a status. */ - static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); + static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); /** Checks if activation is supported by the gemm assembly dispatcher * * @param[in] activation Activation to check @@ -122,6 +123,14 @@ public: * @return True if the function is configured and ready to run */ bool is_configured() const; + /** Indicates if the convolution executes in variable weights mode. + * + * Similar to @ref CpuGemm::isVarWeightsKernel + */ + bool isVarWeightsKernel() const + { + return _arm_gemm && _arm_gemm->isVarWeightsKernel(); + } // Inherited methods overridden: void prepare(ITensorPack &tensors) override; |