aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEConcatenateLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEConcatenateLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NEConcatenateLayer.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/src/runtime/NEON/functions/NEConcatenateLayer.cpp b/src/runtime/NEON/functions/NEConcatenateLayer.cpp
index d338493e51..9a70d32843 100644
--- a/src/runtime/NEON/functions/NEConcatenateLayer.cpp
+++ b/src/runtime/NEON/functions/NEConcatenateLayer.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"
+#include "arm_compute/core/NEON/kernels/NEBatchConcatenateLayerKernel.h"
#include "arm_compute/core/NEON/kernels/NEDepthConcatenateLayerKernel.h"
#include "arm_compute/core/NEON/kernels/NEHeightConcatenateLayerKernel.h"
#include "arm_compute/core/NEON/kernels/NEWidthConcatenateLayerKernel.h"
@@ -112,6 +113,13 @@ void NEConcatenateLayer::configure_internal(std::vector<TensorType *> &&inputs_v
_concat_kernels.emplace_back(std::move(kernel));
break;
}
+ case 3:
+ {
+ auto kernel = support::cpp14::make_unique<NEBatchConcatenateLayerKernel>();
+ kernel->configure(inputs_vector.at(i), offset, output);
+ _concat_kernels.emplace_back(std::move(kernel));
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Axis not supported");
}
@@ -146,6 +154,11 @@ Status NEConcatenateLayer::validate_internal(const std::vector<TensorInfoType *>
ARM_COMPUTE_RETURN_ON_ERROR(NEDepthConcatenateLayerKernel::validate(input, offset, output));
break;
}
+ case 3:
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NEBatchConcatenateLayerKernel::validate(input, offset, output));
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Axis not supported");
}