diff options
Diffstat (limited to 'Utils.cpp')
-rw-r--r-- | Utils.cpp | 65 |
1 files changed, 64 insertions, 1 deletions
@@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2021,2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -767,4 +767,67 @@ void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools) #endif } } + +size_t GetSize(const V1_0::Request& request, const V1_0::RequestArgument& requestArgument) +{ + return request.pools[requestArgument.location.poolIndex].size(); +} + +#ifdef ARMNN_ANDROID_NN_V1_3 +size_t GetSize(const V1_3::Request& request, const V1_0::RequestArgument& requestArgument) +{ + if (request.pools[requestArgument.location.poolIndex].getDiscriminator() == + V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) + { + return request.pools[requestArgument.location.poolIndex].hidlMemory().size(); + } + else + { + return 0; + } +} +#endif + +template <typename ErrorStatus, typename Request> +ErrorStatus ValidateRequestArgument(const Request& request, + const armnn::TensorInfo& tensorInfo, + const V1_0::RequestArgument& requestArgument, + std::string descString) +{ + if (requestArgument.location.poolIndex >= request.pools.size()) + { + std::string err = fmt::format("Invalid {} pool at index {} the pool index is greater than the number " + "of available pools {}", + descString, requestArgument.location.poolIndex, request.pools.size()); + ALOGE(err.c_str()); + return ErrorStatus::GENERAL_FAILURE; + } + const size_t size = GetSize(request, requestArgument); + size_t totalLength = tensorInfo.GetNumBytes(); + + if (static_cast<size_t>(requestArgument.location.offset) + totalLength > size) + { + std::string err = fmt::format("Invalid {} pool at index {} the offset {} and length {} are greater " + "than the pool size {}", descString, requestArgument.location.poolIndex, + requestArgument.location.offset, totalLength, size); + ALOGE(err.c_str()); + return ErrorStatus::GENERAL_FAILURE; + } + return ErrorStatus::NONE; +} + +template V1_0::ErrorStatus ValidateRequestArgument<V1_0::ErrorStatus, V1_0::Request>( + const V1_0::Request& request, + const armnn::TensorInfo& tensorInfo, + const V1_0::RequestArgument& requestArgument, + std::string descString); + +#ifdef ARMNN_ANDROID_NN_V1_3 +template V1_3::ErrorStatus ValidateRequestArgument<V1_3::ErrorStatus, V1_3::Request>( + const V1_3::Request& request, + const armnn::TensorInfo& tensorInfo, + const V1_0::RequestArgument& requestArgument, + std::string descString); +#endif + } // namespace armnn_driver |