author    Francesco Petrogalli <francesco.petrogalli@arm.com>    2022-06-30 10:22:01 +0000
committer Francesco Petrogalli <francesco.petrogalli@arm.com>    2022-07-19 09:26:27 +0000
commit    553f6953fe3bdfad53c11c25f305a16d79d83b24 (patch)
tree      73642b948b79662096f593458c6138d2f7f48ec6 /src/cpu/operators/CpuGemm.cpp
parent    99c46475daf277aa53e6747f9e41209f418fed33 (diff)
download  ComputeLibrary-553f6953fe3bdfad53c11c25f305a16d79d83b24.tar.gz
[ONCPUML-951] Variable weight support for Convolution.
API changes for NEGEMMConvolutionLayer and CpuGemmConv2d.

Built with:

    scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
      build=native -j32 Werror=false validation_tests=1 build_dir=opt \
      standalone=1 asserts=1 experimental_fixed_format_kernels=1 .

Tested with:

    ./build/opt/tests/arm_compute_validation

Hardware where the test executable was run: Neoverse N1

Test coverage:

* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only one supported by the fixed-format kernels)
* F16, F32
* Shapes: RunSmall

Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r--    src/cpu/operators/CpuGemm.cpp    21
1 file changed, 18 insertions(+), 3 deletions(-)
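
For orientation, here is a minimal sketch of how the new query is meant to be driven: the caller asks the assembly backend which blocked weight format its fastest fixed-format kernel expects, then reorders the weights into that layout before configuring the operator. Only CpuGemm::has_opt_impl and the GEMMInfo::fixed_format()/weight_format() getters are confirmed by the diff below; the set_fixed_format()/set_weight_format() setters, the arm_gemm::WeightFormat::ANY/UNSPECIFIED enumerators, and the include paths are assumptions.

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "src/cpu/operators/CpuGemm.h"

    using namespace arm_compute;

    // Sketch: query the assembly backend for the blocked weight format
    // its best fixed-format kernel expects for this problem shape.
    Status query_variable_weights(const ITensorInfo &a, const ITensorInfo &b, const ITensorInfo &d)
    {
        GEMMInfo gemm_info{};
        gemm_info.set_fixed_format(true);                         // assumed setter
        gemm_info.set_weight_format(arm_gemm::WeightFormat::ANY); // assumed setter and enumerator

        arm_gemm::WeightFormat expected_wf = arm_gemm::WeightFormat::UNSPECIFIED; // assumed enumerator
        // Confirmed by the diff: forwards to CpuGemmAssemblyDispatch::has_opt_impl().
        return cpu::CpuGemm::has_opt_impl(expected_wf, &a, &b, nullptr, &d, gemm_info);
        // On success, expected_wf names the blocked layout the caller must
        // reorder the weights into before configuring and running the operator.
    }
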
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 61cd11ece0..f3fff608dc 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -51,6 +51,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
+ asm_info.weight_format = info.weight_format();
return asm_info;
}
@@ -177,7 +178,8 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
if(d->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != d->dimension(0));
+ // For fixed format we are expecting some kind of blocked format for B/RHS, so the dimension won't necessarily match the result matrix any more.
+ ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0));
if(gemm_info.depth_output_gemm3d() != 0)
{
if(gemm_info.reinterpret_input_as_3d())
@@ -277,7 +279,7 @@ void CpuGemm::run(ITensorPack &tensors)
auto c = tensors.get_const_tensor(ACL_SRC_2);
auto d = tensors.get_tensor(ACL_DST);
- if(_asm_glue->is_configured())
+ if(_asm_glue && _asm_glue->is_configured())
{
// Pass c to asm dispatch only if it's the bias tensor
ITensorPack asm_pack = tensors;
@@ -343,7 +345,7 @@ void CpuGemm::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
- if(_asm_glue->is_configured())
+ if(_asm_glue && _asm_glue->is_configured())
{
_asm_glue->prepare(tensors);
}
@@ -365,5 +367,18 @@ experimental::MemoryRequirements CpuGemm::workspace() const
{
return _aux_mem;
}
+
+Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ const GEMMInfo &gemm_info)
+{
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
+
+ return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info);
+}
+
+bool CpuGemm::isVarWeightsKernel() const
+{
+ return _asm_glue && _asm_glue->isVarWeightsKernel();
+}
} // namespace cpu
} // namespace arm_compute
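
A side note on the relaxed check in validate(): a fixed-format kernel stores B/RHS in a blocked layout whose leading dimension is rounded up to the kernel's block size, so it no longer has to equal N of the result matrix. The sketch below uses hypothetical numbers (N = 6, block size 8) purely to illustrate why the old equality check would reject a valid fixed-format configuration.

    #include <cstddef>

    // Leading dimension of a blocked B/RHS: N rounded up to the block size.
    constexpr std::size_t round_up(std::size_t n, std::size_t block)
    {
        return ((n + block - 1) / block) * block;
    }

    // With N = 6 and an interleave block of 8, B occupies 8 columns:
    static_assert(round_up(6, 8) == 8, "N = 6 blocked by 8 occupies 8 columns");
    // So b->dimension(0) == 8 while d->dimension(0) == 6, which is why
    // validate() now skips the equality check when gemm_info.fixed_format()
    // is set.
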