diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEConcatenateLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEConcatenateLayer.cpp | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/src/runtime/NEON/functions/NEConcatenateLayer.cpp b/src/runtime/NEON/functions/NEConcatenateLayer.cpp index d338493e51..9a70d32843 100644 --- a/src/runtime/NEON/functions/NEConcatenateLayer.cpp +++ b/src/runtime/NEON/functions/NEConcatenateLayer.cpp @@ -23,6 +23,7 @@ */ #include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h" +#include "arm_compute/core/NEON/kernels/NEBatchConcatenateLayerKernel.h" #include "arm_compute/core/NEON/kernels/NEDepthConcatenateLayerKernel.h" #include "arm_compute/core/NEON/kernels/NEHeightConcatenateLayerKernel.h" #include "arm_compute/core/NEON/kernels/NEWidthConcatenateLayerKernel.h" @@ -112,6 +113,13 @@ void NEConcatenateLayer::configure_internal(std::vector<TensorType *> &&inputs_v _concat_kernels.emplace_back(std::move(kernel)); break; } + case 3: + { + auto kernel = support::cpp14::make_unique<NEBatchConcatenateLayerKernel>(); + kernel->configure(inputs_vector.at(i), offset, output); + _concat_kernels.emplace_back(std::move(kernel)); + break; + } default: ARM_COMPUTE_ERROR("Axis not supported"); } @@ -146,6 +154,11 @@ Status NEConcatenateLayer::validate_internal(const std::vector<TensorInfoType *> ARM_COMPUTE_RETURN_ON_ERROR(NEDepthConcatenateLayerKernel::validate(input, offset, output)); break; } + case 3: + { + ARM_COMPUTE_RETURN_ON_ERROR(NEBatchConcatenateLayerKernel::validate(input, offset, output)); + break; + } default: ARM_COMPUTE_ERROR("Axis not supported"); } |