diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEConcatenateLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEConcatenateLayer.cpp | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/src/runtime/NEON/functions/NEConcatenateLayer.cpp b/src/runtime/NEON/functions/NEConcatenateLayer.cpp index 9f8a2a1b8e..8df4f4cb62 100644 --- a/src/runtime/NEON/functions/NEConcatenateLayer.cpp +++ b/src/runtime/NEON/functions/NEConcatenateLayer.cpp @@ -146,16 +146,14 @@ Status NEConcatenation::validate(const std::vector<const ITensorInfo *> &inputs_ return Status{}; } -void NEConcatenation::run(InputTensorMap inputs, OutputTensorMap outputs, OperatorTensorMap workspace) +void NEConcatenation::run(ITensorPack &tensors) { - ARM_COMPUTE_UNUSED(workspace); - - if(inputs.empty() || outputs.empty()) + if(tensors.empty()) { ARM_COMPUTE_ERROR("No inputs provided"); } - if(inputs.size() != _num_inputs) + if(static_cast<int>(tensors.size() - 1) != static_cast<int>(_num_inputs)) { ARM_COMPUTE_ERROR("Configured with different number of inputs"); } @@ -163,8 +161,10 @@ void NEConcatenation::run(InputTensorMap inputs, OutputTensorMap outputs, Operat int i = 0; for(auto &k : _concat_kernels) { - const InputTensorMap input = { { TensorType::ACL_SRC, inputs.at(ACL_SRC_VEC + i) } }; - NEScheduler::get().schedule_op(k.get(), Window::DimY, input, outputs); + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(ACL_SRC_VEC + i)); + pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(ACL_DST)); + NEScheduler::get().schedule_op(k.get(), Window::DimY, pack); ++i; } } @@ -216,13 +216,13 @@ Status NEConcatenateLayer::validate(const std::vector<const ITensorInfo *> &inpu void NEConcatenateLayer::run() { - InputTensorMap srcs; + ITensorPack pack; for(unsigned i = 0; i < _impl->num_inputs; ++i) { - srcs.insert(std::make_pair(TensorType::ACL_SRC_VEC + i, _impl->srcs.at(i))); + pack.add_tensor(TensorType::ACL_SRC_VEC + i, _impl->srcs.at(i)); } - const OutputTensorMap dst{ { TensorType::ACL_DST, _impl->dst } }; + pack.add_tensor(TensorType::ACL_DST, _impl->dst); - _impl->op->run(srcs, dst, {}); + _impl->op->run(pack); } } // namespace arm_compute |