aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLConcatenateLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLConcatenateLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLConcatenateLayer.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/runtime/CL/functions/CLConcatenateLayer.cpp b/src/runtime/CL/functions/CLConcatenateLayer.cpp
index 0594a17a7a..1d396f5ebf 100644
--- a/src/runtime/CL/functions/CLConcatenateLayer.cpp
+++ b/src/runtime/CL/functions/CLConcatenateLayer.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/runtime/CL/functions/CLConcatenateLayer.h"
+#include "arm_compute/core/CL/kernels/CLBatchConcatenateLayerKernel.h"
#include "arm_compute/core/CL/kernels/CLDepthConcatenateLayerKernel.h"
#include "arm_compute/core/CL/kernels/CLHeightConcatenateLayerKernel.h"
#include "arm_compute/core/CL/kernels/CLWidthConcatenate2TensorsKernel.h"
@@ -124,6 +125,17 @@ void CLConcatenateLayer::configure(const std::vector<ICLTensor *> &inputs_vector
}
break;
}
+ case 3:
+ {
+ for(unsigned int i = 0; i < _num_inputs; ++i)
+ {
+ auto kernel = support::cpp14::make_unique<CLBatchConcatenateLayerKernel>();
+ kernel->configure(inputs_vector.at(i), offset, output);
+ offset += inputs_vector.at(i)->info()->dimension(_axis);
+ _concat_kernels.emplace_back(std::move(kernel));
+ }
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Axis not supported");
}
@@ -184,6 +196,15 @@ Status CLConcatenateLayer::validate(const std::vector<ITensorInfo *> &inputs_vec
}
break;
}
+ case 3:
+ {
+ for(const auto &input : inputs_vector)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(CLBatchConcatenateLayerKernel::validate(input, offset, output));
+ offset += input->dimension(axis);
+ }
+ break;
+ }
default:
ARM_COMPUTE_ERROR("Axis not supported");
}