aboutsummaryrefslogtreecommitdiff
path: root/src/cpu
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu')
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp48
-rw-r--r--src/cpu/operators/CpuGemm.cpp6
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp22
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h1
4 files changed, 60 insertions, 17 deletions
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 200e04f9a8..9920b863d9 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -47,6 +47,46 @@ enum class GemmMethod
GEMM_HYBRID_QUANTIZED
};
+enum class WeightFormat
+{
+ UNSPECIFIED = 0x1,
+ ANY = 0x2,
+ OHWI = 0x100100,
+ OHWIo2 = 0x100200,
+ OHWIo4 = 0x100400,
+ OHWIo8 = 0x100800,
+ OHWIo16 = 0x101000,
+ OHWIo32 = 0x102000,
+ OHWIo64 = 0x104000,
+ OHWIo128 = 0x108000,
+ OHWIo4i2 = 0x200400,
+ OHWIo4i2_bf16 = 0x200410,
+ OHWIo8i2 = 0x200800,
+ OHWIo8i2_bf16 = 0x200810,
+ OHWIo16i2 = 0x201000,
+ OHWIo16i2_bf16 = 0x201010,
+ OHWIo32i2 = 0x202000,
+ OHWIo32i2_bf16 = 0x202010,
+ OHWIo64i2 = 0x204000,
+ OHWIo64i2_bf16 = 0x204010,
+ OHWIo4i4 = 0x400400,
+ OHWIo4i4_bf16 = 0x400410,
+ OHWIo8i4 = 0x400800,
+ OHWIo8i4_bf16 = 0x400810,
+ OHWIo16i4 = 0x401000,
+ OHWIo16i4_bf16 = 0x401010,
+ OHWIo32i4 = 0x402000,
+ OHWIo32i4_bf16 = 0x402010,
+ OHWIo64i4 = 0x404000,
+ OHWIo64i4_bf16 = 0x404010,
+ OHWIo2i8 = 0x800200,
+ OHWIo4i8 = 0x800400,
+ OHWIo8i8 = 0x800800,
+ OHWIo16i8 = 0x801000,
+ OHWIo32i8 = 0x802000,
+ OHWIo64i8 = 0x804000
+};
+
struct KernelDescription
{
GemmMethod method = GemmMethod::DEFAULT;
@@ -69,6 +109,7 @@ struct GemmConfig
std::string filter = "";
unsigned int inner_block_size = 0;
unsigned int outer_block_size = 0;
+ WeightFormat weight_format = WeightFormat::ANY;
GemmConfig(GemmMethod method)
: method(method)
@@ -111,15 +152,16 @@ public:
bool _indirect_input;
Activation _act;
int _maxthreads;
+ bool _fixed_format;
bool _fast_mode;
const GemmConfig *_cfg;
GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
unsigned int K, unsigned int Ksections, unsigned int nbatches,
unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
- bool fast_mode = false, const GemmConfig *cfg = nullptr)
- : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _fast_mode(fast_mode),
- _cfg(cfg)
+ bool fixed_format = false, bool fast_mode = false, const GemmConfig *cfg = nullptr)
+ : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads),
+ _fixed_format(fixed_format), _fast_mode(fast_mode), _cfg(cfg)
{
}
};
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 9c7ad92761..61cd11ece0 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,6 +50,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.depth_output_gemm3d = info.depth_output_gemm3d();
asm_info.activation_info = info.activation_info();
asm_info.fast_mode = info.fast_math();
+ asm_info.fixed_format = info.fixed_format();
return asm_info;
}
@@ -72,8 +73,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
_run_alpha_scale = alpha != 1.f;
_run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run();
_run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
- _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised
- && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
+ _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
if(run_optimised)
{
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index a97e53d4af..787ea95372 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
@@ -203,12 +203,12 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -576,7 +576,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fast_mode);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -594,7 +594,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
const CPUInfo &ci = NEScheduler::get().cpu_info();
const unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fast_mode);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -644,7 +644,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fast_mode);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode);
switch(a->data_type())
{
case DataType::F32:
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 74359eee72..3c25866f25 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -52,6 +52,7 @@ struct AsmGemmInfo
int64_t padding_left{ 0 };
float padding_value{ 0.f };
bool fast_mode{ false };
+ bool fixed_format{ false };
};
/** Assembly kernel glue */