diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 2 |
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 9 |
2 files changed, 9 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 6e14a68438..470cee1557 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -695,7 +695,7 @@ public: #endif /* Make sure we've been set up correctly. */ - assert(_B_transposed); + assert(FixedFormat || _B_transposed); assert(_working_space); int8_t *working_space_bytes = reinterpret_cast<int8_t *>(_working_space); diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index df02d649f8..77da83070b 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -25,6 +25,7 @@ #include "arm_compute/runtime/NEON/NEScheduler.h" #include "src/core/CPP/Validate.h" +#include "src/core/NEON/kernels/arm_gemm/utils.hpp" #include "src/core/helpers/MemoryHelpers.h" #include "src/core/utils/AssemblyUtils.h" #include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h" @@ -507,14 +508,20 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) const TensorShape tensor_shape = tensor_info->tensor_shape(); const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; - const int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; + int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; const int interleave_by = arm_compute::interleave_by(wf); + const int blocked_by = arm_compute::block_by(wf); // We need to find a new stride that is distance from the data for one // set of output channels to the next if(ldb == tensor_channels && multi_stride_b == 
tensor_channels * tensor_width) { // In this case dimensions that are packed are height, width and channel // so we need to stride it by interleave_by + if(tensor_channels % blocked_by != 0) + { + // We need to pad + tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by; + } ldb = interleave_by * tensor_height * tensor_width * tensor_channels; } else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width)) |