diff options
Diffstat (limited to 'src/cpu/operators/CpuWinogradConv2d.h')
-rw-r--r-- | src/cpu/operators/CpuWinogradConv2d.h | 32 |
1 files changed, 14 insertions, 18 deletions
diff --git a/src/cpu/operators/CpuWinogradConv2d.h b/src/cpu/operators/CpuWinogradConv2d.h index ba9b879431..03bfc51a46 100644 --- a/src/cpu/operators/CpuWinogradConv2d.h +++ b/src/cpu/operators/CpuWinogradConv2d.h @@ -29,8 +29,8 @@ #include "src/core/common/Macros.h" #include "src/cpu/ICpuOperator.h" -#include "src/cpu/kernels/CpuWinogradConv2dKernel.h" #include "src/cpu/kernels/assembly/gemm_common.hpp" +#include "src/cpu/kernels/CpuWinogradConv2dKernel.h" #include "src/cpu/operators/CpuActivation.h" #include "src/cpu/operators/CpuGemm.h" #include "src/cpu/operators/CpuPermute.h" @@ -96,26 +96,22 @@ public: bool enable_fast_math = false); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &constants) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &constants) override; experimental::MemoryRequirements workspace() const override; private: enum AuxTensorIdx { - GemmWorkspace = 0, - Pretranspose = 1, - InterleavedLHS = 2, - TransposedRHS = 3, - TempResult = 4, - TransformedInput = 5, - TransformedOutput = 6, - WorkspaceIO = 7, - TransformedWeights = 8, - PermutedWeights = 9, - PermutedInput = TransformedOutput, - PermutedOutput = TransformedInput, - Count = 10 + /** Slot 0 - 6 reserved for CpuGemm */ + TransformedInput = 7, + TransformedOutput, + WorkspaceIO, + TransformedWeights, + PermutedWeights, + Count, + PermutedInput = TransformedOutput, + PermutedOutput = TransformedInput }; std::unique_ptr<CpuGemm> _gemm_function; std::unique_ptr<CpuActivation> _activation_func; @@ -124,9 +120,9 @@ private: std::unique_ptr<CpuPermute> _permute_input; std::unique_ptr<CpuPermute> _permute_output; std::unique_ptr<CpuPermute> _permute_weights; - experimental::MemoryRequirements _aux_mem{ Count }; + experimental::MemoryRequirements _aux_mem{Count}; std::unique_ptr<arm_conv::ConvolutionArgs> - _conv_args; // Make it unique ptr because this type does not have a default constructor + _conv_args; // Make it unique ptr because this type does not have a default constructor arm_conv::winograd::WinogradImpl _winograd_impl; DataLayout _data_layout; TensorInfo _winograd_transformed_input; |