diff options
Diffstat (limited to 'utils/GraphUtils.cpp')
-rw-r--r-- | utils/GraphUtils.cpp | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index 145e44950b..0edb6f2a56 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -129,6 +129,45 @@ bool DummyAccessor::access_tensor(ITensor &tensor) return ret; } +NumPyAccessor::NumPyAccessor(std::string npy_path, TensorShape shape, DataType data_type, std::ostream &output_stream) + : _npy_tensor(), _filename(std::move(npy_path)), _output_stream(output_stream) +{ + NumPyBinLoader loader(_filename); + + TensorInfo info(shape, 1, data_type); + _npy_tensor.allocator()->init(info); + _npy_tensor.allocator()->allocate(); + + loader.access_tensor(_npy_tensor); +} + +template <typename T> +void NumPyAccessor::access_numpy_tensor(ITensor &tensor) +{ + const int num_elements = tensor.info()->total_size(); + int num_mismatches = utils::compare_tensor<T>(tensor, _npy_tensor); + float percentage_mismatches = static_cast<float>(num_mismatches) / num_elements; + + _output_stream << "Results: " << 100.f - (percentage_mismatches * 100) << " % matches with the provided output[" << _filename << "]." << std::endl; +} + +bool NumPyAccessor::access_tensor(ITensor &tensor) +{ + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32); + ARM_COMPUTE_ERROR_ON(_npy_tensor.info()->dimension(0) != tensor.info()->dimension(0)); + + switch(tensor.info()->data_type()) + { + case DataType::F32: + access_numpy_tensor<float>(tensor); + break; + default: + ARM_COMPUTE_ERROR("NOT SUPPORTED!"); + } + + return false; +} + PPMAccessor::PPMAccessor(std::string ppm_path, bool bgr, std::unique_ptr<IPreprocessor> preprocessor) : _ppm_path(std::move(ppm_path)), _bgr(bgr), _preprocessor(std::move(preprocessor)) { |