Diffstat (limited to 'src/cpu')
-rw-r--r--   src/cpu/kernels/CpuIm2ColKernel.cpp                     | 10
-rw-r--r--   src/cpu/kernels/cast/generic/neon/bfloat16.cpp          |  4
-rw-r--r--   src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp  | 26
3 files changed, 20 insertions, 20 deletions
diff --git a/src/cpu/kernels/CpuIm2ColKernel.cpp b/src/cpu/kernels/CpuIm2ColKernel.cpp
index 875d66594f..25ff6c291c 100644
--- a/src/cpu/kernels/CpuIm2ColKernel.cpp
+++ b/src/cpu/kernels/CpuIm2ColKernel.cpp
@@ -359,11 +359,11 @@ void CpuIm2ColKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const
         case DataType::F32:
             _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col<float, false, true> : &CpuIm2ColKernel::run_im2col<float, true, true>;
             break;
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+#if defined(ARM_COMPUTE_ENABLE_BF16)
         case DataType::BFLOAT16:
             _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col<bfloat16, false, true> : &CpuIm2ColKernel::run_im2col<bfloat16, true, true>;
             break;
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
             _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col<float16_t, false, true> : &CpuIm2ColKernel::run_im2col<float16_t, true, true>;
@@ -385,11 +385,11 @@ void CpuIm2ColKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const
         case DataType::F32:
             _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col<float, false, false> : &CpuIm2ColKernel::run_im2col<float, true, false>;
             break;
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+#if defined(ARM_COMPUTE_ENABLE_BF16)
         case DataType::BFLOAT16:
             _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col<bfloat16, false, false> : &CpuIm2ColKernel::run_im2col<bfloat16, true, false>;
             break;
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
             _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col<float16_t, false, false> : &CpuIm2ColKernel::run_im2col<float16_t, true, false>;
@@ -453,4 +453,4 @@ size_t CpuIm2ColKernel::get_mws(const CPUInfo &platform, size_t thread_count) co
 }
 } // namespace kernels
 } // namespace cpu
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/cpu/kernels/cast/generic/neon/bfloat16.cpp b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp
index aac4ef4ca0..eed537039f 100644
--- a/src/cpu/kernels/cast/generic/neon/bfloat16.cpp
+++ b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+#if defined(ARM_COMPUTE_ENABLE_BF16)
 
 #include "arm_compute/core/TensorInfo.h"
 #include "src/core/NEON/wrapper/wrapper.h"
@@ -142,4 +142,4 @@ void neon_bfloat16_to_fp32_cast(const ITensor *_src, ITensor *_dst, const Thread
 } // namespace cpu
 } // namespace arm_compute
 
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 5694a3d9ee..558ff41a5c 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@ public:
                    const std::vector<int32_t> &multipliers);
 
     // Inherited methods overridden:
-    void run(ITensorPack &tensors) override;
-    void prepare(ITensorPack &tensors) override;
+    void run(ITensorPack &tensors) override;
+    void prepare(ITensorPack &tensors) override;
     bool is_configured() const override;
     experimental::MemoryRequirements workspace() const override;
     bool isVarWeightsKernel() const override
@@ -210,12 +210,12 @@ private:
     /** Indirect buffer */
     std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
     std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
-    std::vector<TypeInput> _indirect_pad{};
-    arm_gemm::ConvolutionParameters _cp{};
-    experimental::MemoryRequirements _aux_mem{ Count };
-    bool _B_pretranspose_required{ false };
-    bool _is_b_constant{ true };
-    bool _is_c_constant{ true };
+    std::vector<TypeInput> _indirect_pad{};
+    arm_gemm::ConvolutionParameters _cp{};
+    experimental::MemoryRequirements _aux_mem{ Count };
+    bool _B_pretranspose_required{ false };
+    bool _is_b_constant{ true };
+    bool _is_c_constant{ true };
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -712,14 +712,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_we
         }
         break;
 #endif /* __aarch64__ */
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+#if defined(ARM_COMPUTE_ENABLE_BF16)
         case DataType::BFLOAT16:
         {
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(args, {})),
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
                                             "We could not find an optimized kernel for BFLOAT16 input and F32 output");
             break;
         }
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
             ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
@@ -821,11 +821,11 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo
         }
         break;
 #endif /* __aarch64__ */
-#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+#if defined(ARM_COMPUTE_ENABLE_BF16)
         case DataType::BFLOAT16:
             create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
            break;
-#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
             create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
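
Not part of the patch: a minimal, self-contained sketch of the guard pattern the change standardizes on. A single ARM_COMPUTE_ENABLE_BF16 macro (assumed to be supplied by the build system, e.g. -DARM_COMPUTE_ENABLE_BF16) gates the BF16 code path, replacing the direct checks on __ARM_FEATURE_BF16_VECTOR_ARITHMETIC or ARM_COMPUTE_FORCE_BF16. The DataType enum and pick_kernel function below are illustrative stand-ins, not ComputeLibrary APIs.

#include <cstdio>

enum class DataType
{
    F32,
    BFLOAT16
};

// Selects a kernel name for the given data type. The BFLOAT16 case is only
// compiled in when the build defines ARM_COMPUTE_ENABLE_BF16; otherwise that
// data type falls through to the default branch.
static const char *pick_kernel(DataType dt)
{
    switch (dt)
    {
        case DataType::F32:
            return "f32_kernel";
#if defined(ARM_COMPUTE_ENABLE_BF16)
        case DataType::BFLOAT16:
            return "bf16_kernel";
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
        default:
            return "unsupported";
    }
}

int main()
{
    // Prints "bf16_kernel" when built with -DARM_COMPUTE_ENABLE_BF16,
    // "unsupported" otherwise.
    std::printf("%s\n", pick_kernel(DataType::BFLOAT16));
    return 0;
}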