diff options
author | Matthew Sloyan <matthew.sloyan@arm.com> | 2023-03-30 10:12:08 +0100 |
---|---|---|
committer | ryan.oshea3 <ryan.oshea3@arm.com> | 2023-04-05 20:36:32 +0000 |
commit | ebe392df1635790bf21714549adb97f2f75559e1 (patch) | |
tree | 6fb8e56cc755d7c47a62bbe72c54b6ca5445377d /delegate/common/src/test/DelegateTestInterpreterUtils.hpp | |
parent | ac9607f401dc30003aa97bd179a06d6b8a32139f (diff) | |
download | armnn-ebe392df1635790bf21714549adb97f2f75559e1.tar.gz |
IVGCVSW-7562 Implement DelegateTestInterpreter for classic delegate
* Updated all tests to use new DelegateTestInterpreter.
* Fixed some unit tests where the shape was incorrect.
* Add file identifier to FlatBuffersBuilder, as it is required for
validation when creating the model using new API.
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: I1c4f5464367b35d4528571fa94d14bfaef18fb4d
Diffstat (limited to 'delegate/common/src/test/DelegateTestInterpreterUtils.hpp')
-rw-r--r-- | delegate/common/src/test/DelegateTestInterpreterUtils.hpp | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/delegate/common/src/test/DelegateTestInterpreterUtils.hpp b/delegate/common/src/test/DelegateTestInterpreterUtils.hpp new file mode 100644 index 0000000000..396c75c22e --- /dev/null +++ b/delegate/common/src/test/DelegateTestInterpreterUtils.hpp @@ -0,0 +1,110 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Exceptions.hpp> + +#include <tensorflow/lite/core/c/c_api.h> +#include <tensorflow/lite/kernels/custom_ops_register.h> +#include <tensorflow/lite/kernels/register.h> + +#include <type_traits> + +namespace delegateTestInterpreter +{ + +inline TfLiteTensor* GetInputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index) +{ + TfLiteTensor* inputTensor = TfLiteInterpreterGetInputTensor(interpreter, index); + if(inputTensor == nullptr) + { + throw armnn::Exception("Input tensor was not found at the given index: " + std::to_string(index)); + } + return inputTensor; +} + +inline const TfLiteTensor* GetOutputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index) +{ + const TfLiteTensor* outputTensor = TfLiteInterpreterGetOutputTensor(interpreter, index); + if(outputTensor == nullptr) + { + throw armnn::Exception("Output tensor was not found at the given index: " + std::to_string(index)); + } + return outputTensor; +} + +inline TfLiteModel* CreateTfLiteModel(std::vector<char>& data) +{ + TfLiteModel* tfLiteModel = TfLiteModelCreate(data.data(), data.size()); + if(tfLiteModel == nullptr) + { + throw armnn::Exception("An error has occurred when creating the TfLiteModel."); + } + return tfLiteModel; +} + +inline TfLiteInterpreterOptions* CreateTfLiteInterpreterOptions() +{ + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + if(options == nullptr) + { + throw armnn::Exception("An error has occurred when creating the TfLiteInterpreterOptions."); + } + return options; +} + +inline tflite::ops::builtin::BuiltinOpResolver GenerateCustomOpResolver(const std::string& opName) +{ + tflite::ops::builtin::BuiltinOpResolver opResolver; + if (opName == "MaxPool3D") + { + opResolver.AddCustom("MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D()); + } + else if (opName == "AveragePool3D") + { + opResolver.AddCustom("AveragePool3D", tflite::ops::custom::Register_AVG_POOL_3D()); + } + else + { + throw armnn::Exception("The custom op isn't supported by the DelegateTestInterpreter."); + } + return opResolver; +} + +template<typename T> +inline TfLiteStatus CopyFromBufferToTensor(TfLiteTensor* tensor, std::vector<T>& values) +{ + // Make sure there is enough bytes allocated to copy into for uint8_t and int16_t case. + if(tensor->bytes < values.size() * sizeof(T)) + { + throw armnn::Exception("Tensor has not been allocated to match number of values."); + } + + // Requires uint8_t and int16_t specific case as the number of bytes is larger than values passed when creating + // TFLite tensors of these types. Otherwise, use generic TfLiteTensorCopyFromBuffer function. + TfLiteStatus status = kTfLiteOk; + if (std::is_same<T, uint8_t>::value) + { + for (unsigned int i = 0; i < values.size(); ++i) + { + tensor->data.uint8[i] = values[i]; + } + } + else if (std::is_same<T, int16_t>::value) + { + for (unsigned int i = 0; i < values.size(); ++i) + { + tensor->data.i16[i] = values[i]; + } + } + else + { + status = TfLiteTensorCopyFromBuffer(tensor, values.data(), values.size() * sizeof(T)); + } + return status; +} + +} // anonymous namespace
\ No newline at end of file |