aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.h')
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h21
1 files changed, 15 insertions, 6 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 3c25866f25..4ef108d430 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -53,6 +53,7 @@ struct AsmGemmInfo
float padding_value{ 0.f };
bool fast_mode{ false };
bool fixed_format{ false };
+ arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
};
/** Assembly kernel glue */
@@ -73,6 +74,7 @@ public:
virtual void prepare(ITensorPack &tensors) = 0;
virtual experimental::MemoryRequirements workspace() const = 0;
virtual bool is_configured() const = 0;
+ virtual bool isVarWeightsKernel() const = 0;
virtual ~IFallback() = default;
};
@@ -101,15 +103,14 @@ public:
/** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
*
- * @param[in] a Input tensor info (Matrix A)
- * @param[in] b Input tensor info (Matrix B)
- * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations
- * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] info GEMM meta-data
+ * This method has the same purpose as @ref
+ * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
+ * the value of arm_gemm::WeightFormat needs to be passed via the
+ * parameter info.
*
* @return a status.
*/
- static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+ static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
/** Checks if activation is supported by the gemm assembly dispatcher
*
* @param[in] activation Activation to check
@@ -122,6 +123,14 @@ public:
* @return True if the function is configured and ready to run
*/
bool is_configured() const;
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * Similar to @ref CpuGemm::isVarWeightsKernel; returns false when _arm_gemm is not set.
+ */
+ bool isVarWeightsKernel() const
+ {
+ return _arm_gemm && _arm_gemm->isVarWeightsKernel();
+ }
// Inherited methods overridden:
void prepare(ITensorPack &tensors) override;