aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
diff options
context:
space:
mode:
authorAnthony Barbier <anthony.barbier@arm.com>2018-07-20 17:49:35 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commiteaefd002a5d6509dd5f12e98b538c99b33c2c1ee (patch)
tree18951e67cf2c0c0b91e88d9174d0c350890456a1 /arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
parentc8e84b5a3872eda6748d77dbaf8548ad99f4c0cd (diff)
downloadComputeLibrary-eaefd002a5d6509dd5f12e98b538c99b33c2c1ee.tar.gz
COMPMID-1419: Make NEGEMMAssemblyDispatch dynamically typed instead of templated
This makes it easier to integrate in GEMMLowpMatrixMultiplyCore Change-Id: Ibf80803f016a2e6a24d943ffafb50b48f04ec545 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140868 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h')
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h89
1 files changed, 30 insertions, 59 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
index 1c9ecb088e..382ef1caba 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
@@ -35,7 +35,6 @@
namespace arm_compute
{
/** Assembly kernel glue */
-template <typename TypeInput, typename TypeOutput>
class NEGEMMAssemblyDispatch : public IFunction
{
public:
@@ -43,12 +42,21 @@ public:
NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
/** Prevent instances of this class from being copy constructed */
- NEGEMMAssemblyDispatch(const NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &) = delete;
+ NEGEMMAssemblyDispatch(const NEGEMMAssemblyDispatch &) = delete;
/** Prevent instances of this class from being copied */
- NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &operator=(const NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &) = delete;
- NEGEMMAssemblyDispatch(NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &&) = default;
- NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &operator=(NEGEMMAssemblyDispatch<TypeInput, TypeOutput> &&) = default;
- ~NEGEMMAssemblyDispatch() = default;
+ NEGEMMAssemblyDispatch &operator=(const NEGEMMAssemblyDispatch &) = delete;
+ NEGEMMAssemblyDispatch(NEGEMMAssemblyDispatch &&) = default;
+ NEGEMMAssemblyDispatch &operator=(NEGEMMAssemblyDispatch &&) = default;
+ ~NEGEMMAssemblyDispatch() = default;
+
+ class IFallback
+ {
+ public:
+ virtual void run() = 0;
+ virtual void prepare() = 0;
+ virtual bool is_configured() const = 0;
+ virtual ~IFallback() = default;
+ };
private:
/** ACL Function */
@@ -68,53 +76,9 @@ private:
*/
bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
- //Fallback: use arm_gemm's AssemblyGemm:
- class Fallback
- {
-#ifndef DOXYGEN_SKIP_THIS
- public:
- /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel.
- * The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2)
- */
- void run();
- void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
- void prepare();
- bool is_configured() const;
-#endif /* DOXYGEN_SKIP_THIS */
-
- private:
- /** Allocate a workspace tensor.
- *
- * @param[in] workspace_size Size to allocate.
- * @param[in] memory_group Tensor memory group.
- * @param[in] alignment Workspace memory alignment.
- */
- void allocate_workspace(size_t workspace_size, MemoryGroup *memory_group, size_t alignment);
-
- /** Assembly Gemm kernel */
- std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
- /** Optimised NEON kernel */
- std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
- /** Input A */
- const ITensor *_a
- {
- nullptr
- };
- /** Input B */
- const ITensor *_b
- {
- nullptr
- };
- /** Output */
- ITensor *_d{ nullptr };
- /** GEMM workspace */
- Tensor _workspace{};
- /** Pre-transpose tensor */
- Tensor _pretranspose{};
- /** Prepared flag */
- bool _is_prepared{ false };
- } _arm_gemm; /**< Fallback in case ACL doesn't have a function */
- MemoryGroup _memory_group; /**< Function memory group */
+ /** Interface for the arm_gemm fallback */
+ std::unique_ptr<IFallback> _arm_gemm;
+ MemoryGroup _memory_group; /**< Function memory group */
public:
/** If supported create an ACL function else fallback to the arm_gemm function.
*
@@ -126,6 +90,19 @@ public:
* @param[in] pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)?
*/
void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
+
+ /** Indicates whether or not this function can be used to process the given parameters.
+ *
+ * @param[in] a Input tensor (Matrix A)
+ * @param[in] b Input tensor (Matrix B)
+ * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] alpha Scalar multiplier to apply to AB matrix product.
+ * @param[in] beta Scalar multiplier to apply to input D matrix before adding product.
+ * @param[in] pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)?
+ *
+ * @return a status.
+ */
+ static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint);
/** Was the function successfully configured ?
*
* @return True if the function is configured and ready to run
@@ -137,11 +114,5 @@ public:
void run() override;
};
-/** Float 32 assembly dispatch kernel */
-using NEGEMMAssemblyDispatchF32 = NEGEMMAssemblyDispatch<float, float>;
-/** Uint 8 to Uint 32 assembly dispatch kernel */
-using NEGEMMAssemblyDispatchU8U32 = NEGEMMAssemblyDispatch<uint8_t, uint32_t>;
-/** Int 8 to Int 32 assembly dispatch kernel */
-using NEGEMMAssemblyDispatchS8S32 = NEGEMMAssemblyDispatch<int8_t, int32_t>;
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMASSEMBLYDISPATCH_H__ */