aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuWinogradConv2d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuWinogradConv2d.h')
-rw-r--r--src/cpu/operators/CpuWinogradConv2d.h32
1 files changed, 14 insertions, 18 deletions
diff --git a/src/cpu/operators/CpuWinogradConv2d.h b/src/cpu/operators/CpuWinogradConv2d.h
index ba9b879431..03bfc51a46 100644
--- a/src/cpu/operators/CpuWinogradConv2d.h
+++ b/src/cpu/operators/CpuWinogradConv2d.h
@@ -29,8 +29,8 @@
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuOperator.h"
-#include "src/cpu/kernels/CpuWinogradConv2dKernel.h"
#include "src/cpu/kernels/assembly/gemm_common.hpp"
+#include "src/cpu/kernels/CpuWinogradConv2dKernel.h"
#include "src/cpu/operators/CpuActivation.h"
#include "src/cpu/operators/CpuGemm.h"
#include "src/cpu/operators/CpuPermute.h"
@@ -96,26 +96,22 @@ public:
bool enable_fast_math = false);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
private:
enum AuxTensorIdx
{
- GemmWorkspace = 0,
- Pretranspose = 1,
- InterleavedLHS = 2,
- TransposedRHS = 3,
- TempResult = 4,
- TransformedInput = 5,
- TransformedOutput = 6,
- WorkspaceIO = 7,
- TransformedWeights = 8,
- PermutedWeights = 9,
- PermutedInput = TransformedOutput,
- PermutedOutput = TransformedInput,
- Count = 10
+ /** Slot 0 - 6 reserved for CpuGemm */
+ TransformedInput = 7,
+ TransformedOutput,
+ WorkspaceIO,
+ TransformedWeights,
+ PermutedWeights,
+ Count,
+ PermutedInput = TransformedOutput,
+ PermutedOutput = TransformedInput
};
std::unique_ptr<CpuGemm> _gemm_function;
std::unique_ptr<CpuActivation> _activation_func;
@@ -124,9 +120,9 @@ private:
std::unique_ptr<CpuPermute> _permute_input;
std::unique_ptr<CpuPermute> _permute_output;
std::unique_ptr<CpuPermute> _permute_weights;
- experimental::MemoryRequirements _aux_mem{ Count };
+ experimental::MemoryRequirements _aux_mem{Count};
std::unique_ptr<arm_conv::ConvolutionArgs>
- _conv_args; // Make it unique ptr because this type does not have a default constructor
+ _conv_args; // Make it unique ptr because this type does not have a default constructor
arm_conv::winograd::WinogradImpl _winograd_impl;
DataLayout _data_layout;
TensorInfo _winograd_transformed_input;