author    Georgios Pinitas <georgios.pinitas@arm.com>  2018-10-09 15:13:12 +0100
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:55:45 +0000
commit    ecae3a14746fc7f678735b1a82347bd03f9a397f (patch)
tree      4c02d64cb9cbee8d270e2c498cc8ff1655bcdb36
parent    df3b5bb87296fdcde8ef88153f6365d693e80295 (diff)
download  ComputeLibrary-ecae3a14746fc7f678735b1a82347bd03f9a397f.tar.gz
COMPMID-1451: Enable dot kernels in NEGEMMAssembly functions
Change-Id: I9dd26b80025ea3a4c66f5f0bf41b7a98dd0d3aa4
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/152549
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
-rw-r--r-- src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp |  6
-rw-r--r-- src/runtime/CPUUtils.cpp                                                  | 15
2 files changed, 11 insertions(+), 10 deletions(-)
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
index 3d42f8a51f..2c9cd320f0 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.cpp
@@ -37,7 +37,7 @@ template <typename To, typename Tr, bool use_dot>
void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker,
const BlockSizes &block_sizes, const INEGEMMWrapperKernel::Params &params, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads)
{
- using strategy = typename Kernel<To>::strategy;
+ using strategy = typename Kernel<To, use_dot>::strategy;
_prepared_a = prepared_a;
_transformed_b = transformed_b;
@@ -57,7 +57,7 @@ template <typename To, typename Tr, bool use_dot>
void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::transform(const MatrixMultiplyWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset,
const Coordinates &end_offset)
{
- using strategy = typename Kernel<To>::strategy;
+ using strategy = typename Kernel<To, use_dot>::strategy;
strategy strat(info.cpu_info);
TensorAccessor<To> prepared_a(*_prepared_a);
@@ -98,7 +98,7 @@ void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::transform(
template <typename To, typename Tr, bool use_dot>
void NEGEMMInterleavedMatrixMultiplyWrapperTemplate<To, Tr, use_dot>::create_workloads(std::vector<MatrixMultiplyWorkload> &workloads)
{
- using strategy = typename Kernel<To>::strategy;
+ using strategy = typename Kernel<To, use_dot>::strategy;
unsigned int offset_transformed_b = 0;
execute_window_loop(_block_walker, [&](const Coordinates & id)
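The three hunks above make the same change in configure(), transform() and create_workloads(): the wrapper's use_dot template parameter is now forwarded to the Kernel trait, so strategy selection can resolve to a dot-product micro-kernel. A minimal sketch of how such a two-parameter trait could be specialized; the type and strategy names below are illustrative stand-ins, not the library's actual definitions:

// Illustrative sketch only: the real Kernel trait and strategy types live in
// the NEON assembly wrapper headers; the names below are hypothetical.
#include <cstdint>

struct gemm_s8_generic { /* plain int8 GEMM micro-kernel */ };
struct gemm_s8_dot     { /* int8 GEMM micro-kernel built around SDOT/UDOT */ };

template <typename To, bool use_dot>
struct Kernel;

template <>
struct Kernel<int8_t, false>
{
    using strategy = gemm_s8_generic; // fallback when dot product is unavailable
};

template <>
struct Kernel<int8_t, true>
{
    using strategy = gemm_s8_dot;     // selected when use_dot is true
};

With the trait keyed on both parameters, every member function of a given wrapper instantiation resolves to the same strategy, which is what the patch relies on.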
diff --git a/src/runtime/CPUUtils.cpp b/src/runtime/CPUUtils.cpp
index f8feb33838..d81337c2b8 100644
--- a/src/runtime/CPUUtils.cpp
+++ b/src/runtime/CPUUtils.cpp
@@ -339,17 +339,18 @@ void get_cpu_configuration(CPUInfo &cpuinfo)
populate_models_cpuinfo(percpu);
}
int j(0);
- // Update dot product and FP16 support if all CPUs support these features:
- bool all_support_dot = true;
- bool all_support_fp16 = true;
+ // Update dot product and FP16 support if one of the CPUs support these features
+ // We assume that the system does not have mixed architectures
+ bool one_supports_dot = false;
+ bool one_supports_fp16 = false;
for(const auto &v : percpu)
{
- all_support_dot &= model_supports_dot(v);
- all_support_fp16 &= model_supports_fp16(v);
+ one_supports_dot |= model_supports_dot(v);
+ one_supports_fp16 |= model_supports_fp16(v);
cpuinfo.set_cpu_model(j++, v);
}
- cpuinfo.set_dotprod(all_support_dot || hwcaps_dot_support);
- cpuinfo.set_fp16(all_support_fp16 || hwcaps_fp16_support);
+ cpuinfo.set_dotprod(one_supports_dot || hwcaps_dot_support);
+ cpuinfo.set_fp16(one_supports_fp16 || hwcaps_fp16_support);
#else /* !defined(BARE_METAL) && (defined(__arm__) || defined(__aarch64__)) */
ARM_COMPUTE_UNUSED(cpuinfo);
#endif /* !defined(BARE_METAL) && (defined(__arm__) || defined(__aarch64__)) */
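The CPUUtils.cpp hunk relaxes the capability check from "every core must report the feature" to "at least one core reports it", under the stated assumption that the system does not mix incompatible architectures. A standalone sketch of that any-of reduction follows; the CpuModel values are hypothetical, and model_supports_dot() is a stand-in for the helper queried in CPUUtils.cpp:

#include <vector>

// Hypothetical per-core model identifiers standing in for the library's CPU model enum.
enum class CpuModel { GENERIC, WITH_DOT };

// Stand-in for the model_supports_dot() helper used in CPUUtils.cpp.
static bool model_supports_dot(CpuModel m)
{
    return m == CpuModel::WITH_DOT;
}

// Any-of reduction: a single capable core is enough to enable the dot kernels,
// assuming no mixed architectures on the system.
static bool any_core_supports_dot(const std::vector<CpuModel> &percpu)
{
    bool one_supports_dot = false;
    for(const auto &m : percpu)
    {
        one_supports_dot |= model_supports_dot(m);
    }
    return one_supports_dot;
}

On a heterogeneous system where only some cores report the dot-product feature, the previous all-of check would have left the dot kernels disabled; the any-of check enables them, which is what allows the NEGEMMAssembly functions to use the dot kernels this commit turns on.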