aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp55
1 files changed, 37 insertions, 18 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 45b3232423..df02d649f8 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
bool isVarWeightsKernel() const override
@@ -210,12 +210,12 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -493,6 +493,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
if(!_gemm_kernel_asm->B_is_pretransposed())
{
ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
if(is_fixed_format(wf))
{
@@ -501,17 +502,35 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// as a 2D tensor at arm_gemm level, where the rows are
// O'/<interleave_by> and the columns are <interleave_by> *
// H * W * I'.
- ITensorInfo *tensor_info = b->info();
- const DataLayout data_layout = tensor_info->data_layout();
- const TensorShape tensor_shape = tensor_info->tensor_shape();
- const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
- const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
- const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
- const int interleave_by = arm_compute::interleave_by(wf);
- ldb = (interleave_by * H * W * Ip);
+ ITensorInfo *tensor_info = b->info();
+ const DataLayout data_layout = tensor_info->data_layout();
+ const TensorShape tensor_shape = tensor_info->tensor_shape();
+ const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
+ const int interleave_by = arm_compute::interleave_by(wf);
+ // We need to find a new stride that is distance from the data for one
+ // set of output channels to the next
+ if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width)
+ {
+ // In this case dimensions that are packed are height, width and channel
+ // so we need to stride it by interleave_by
+ ldb = interleave_by * tensor_height * tensor_width * tensor_channels;
+ }
+ else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width))
+ {
+ // In this case dimension that is packed is only height
+ // so we need to stride only height by interleave_by
+ ldb = interleave_by * tensor_height;
+ }
+ else
+ {
+ // If dimensions are not packed as above error is thrown
+ // as at the moment other forms of packing are not supported
+ ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel");
+ }
}
- multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant