aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLPReluLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLPReluLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLPReluLayer.cpp30
1 files changed, 17 insertions, 13 deletions
diff --git a/src/runtime/CL/functions/CLPReluLayer.cpp b/src/runtime/CL/functions/CLPReluLayer.cpp
index fbb466acc8..e03bd13284 100644
--- a/src/runtime/CL/functions/CLPReluLayer.cpp
+++ b/src/runtime/CL/functions/CLPReluLayer.cpp
@@ -44,19 +44,22 @@ void configure_border_handler(const CLCompileContext &compile_context, CLFillBor
}
}
}
-void select_border_input(InputTensorMap &tensor_map, InputTensorMap &inputs, OutputTensorMap &outputs)
+
+ITensorPack select_border_input(ITensorPack &tensors)
{
- if(outputs.at(TensorType::ACL_DST)->info()->dimension(0) > 1)
+ ITensorPack pack;
+ if(tensors.get_tensor(TensorType::ACL_DST)->info()->dimension(0) > 1)
{
- if(inputs.at(TensorType::ACL_SRC_1)->info()->dimension(0) == 1)
+ if(tensors.get_const_tensor(TensorType::ACL_SRC_1)->info()->dimension(0) == 1)
{
- tensor_map[TensorType::ACL_SRC] = inputs.at(TensorType::ACL_SRC_1);
+ pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1));
}
else
{
- tensor_map[TensorType::ACL_SRC] = inputs.at(TensorType::ACL_SRC_0);
+ pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_0));
}
}
+ return pack;
}
} // namespace
@@ -80,12 +83,11 @@ Status CLPReluLayer::validate(const ITensorInfo *input, const ITensorInfo *alpha
return CLArithmeticOperationKernel::validate(ArithmeticOperation::PRELU, input, alpha, output);
}
-void CLPReluLayer::run(InputTensorMap inputs, OutputTensorMap outputs, OperatorTensorMap workspace)
+void CLPReluLayer::run(ITensorPack &tensors)
{
- InputTensorMap src;
- select_border_input(src, inputs, outputs);
- CLScheduler::get().enqueue_op(_border_handler, src, {});
- ICLOperator::run(inputs, outputs, workspace);
+ auto border_pack = select_border_input(tensors);
+ CLScheduler::get().enqueue_op(_border_handler, border_pack);
+ ICLOperator::run(tensors);
}
} // namespace experimental
@@ -126,9 +128,11 @@ Status CLPReluLayer::validate(const ITensorInfo *input, const ITensorInfo *alpha
void CLPReluLayer::run()
{
- const InputTensorMap src{ { TensorType::ACL_SRC_0, _impl->src_0 }, { TensorType::ACL_SRC_1, _impl->src_1 } };
- const OutputTensorMap dst{ { TensorType::ACL_DST, _impl->dst } };
+ ITensorPack pack;
+ pack.add_tensor(TensorType::ACL_SRC_0, _impl->src_0);
+ pack.add_tensor(TensorType::ACL_SRC_1, _impl->src_1);
+ pack.add_tensor(TensorType::ACL_DST, _impl->dst);
- _impl->op->run(src, dst, {});
+ _impl->op->run(pack);
}
} // namespace arm_compute