diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/GraphUtils.cpp | 13 | ||||
-rw-r--r-- | utils/GraphUtils.h | 4 | ||||
-rw-r--r-- | utils/Utils.h | 11 |
3 files changed, 17 insertions, 11 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index b714c55136..26ea02a9ff 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -153,24 +153,29 @@ NumPyAccessor::NumPyAccessor(std::string npy_path, TensorShape shape, DataType d } template <typename T> -void NumPyAccessor::access_numpy_tensor(ITensor &tensor) +void NumPyAccessor::access_numpy_tensor(ITensor &tensor, T tolerance) { const int num_elements = tensor.info()->tensor_shape().total_size(); - int num_mismatches = utils::compare_tensor<T>(tensor, _npy_tensor); + int num_mismatches = utils::compare_tensor<T>(tensor, _npy_tensor, tolerance); float percentage_mismatches = static_cast<float>(num_mismatches) / num_elements; _output_stream << "Results: " << 100.f - (percentage_mismatches * 100) << " % matches with the provided output[" << _filename << "]." << std::endl; + _output_stream << " " << num_elements - num_mismatches << " out of " << num_elements << " matches with the provided output[" << _filename << "]." << std::endl + << std::endl; } bool NumPyAccessor::access_tensor(ITensor &tensor) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32, DataType::QASYMM8); ARM_COMPUTE_ERROR_ON(_npy_tensor.info()->dimension(0) != tensor.info()->dimension(0)); switch(tensor.info()->data_type()) { + case DataType::QASYMM8: + access_numpy_tensor<qasymm8_t>(tensor, 0); + break; case DataType::F32: - access_numpy_tensor<float>(tensor); + access_numpy_tensor<float>(tensor, 0.0001f); break; default: ARM_COMPUTE_ERROR("NOT SUPPORTED!"); diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h index 131378e5bd..47656766a6 100644 --- a/utils/GraphUtils.h +++ b/utils/GraphUtils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -160,7 +160,7 @@ public: private: template <typename T> - void access_numpy_tensor(ITensor &tensor); + void access_numpy_tensor(ITensor &tensor, T tolerance); Tensor _npy_tensor; const std::string _filename; diff --git a/utils/Utils.h b/utils/Utils.h index 04ccc3e812..788ae4eeb7 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -782,15 +782,16 @@ void init_sgemm_output(T &dst, T &src0, T &src1, arm_compute::DataType dt) */ uint64_t get_mem_free_from_meminfo(); -/** Compare to tensor +/** Compare two tensors * - * @param[in] tensor1 First tensor to be compared. - * @param[in] tensor2 Second tensor to be compared. + * @param[in] tensor1 First tensor to be compared. + * @param[in] tensor2 Second tensor to be compared. + * @param[in] tolerance Tolerance used for the comparison. * * @return The number of mismatches */ template <typename T> -int compare_tensor(ITensor &tensor1, ITensor &tensor2) +int compare_tensor(ITensor &tensor1, ITensor &tensor2, T tolerance) { ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&tensor1, &tensor2); ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(&tensor1, &tensor2); @@ -807,7 +808,7 @@ int compare_tensor(ITensor &tensor1, ITensor &tensor2) execute_window_loop(window, [&](const Coordinates & id) { - if(std::abs(*reinterpret_cast<T *>(itensor1.ptr()) - *reinterpret_cast<T *>(itensor2.ptr())) > 0.0001) + if(std::abs(*reinterpret_cast<T *>(itensor1.ptr()) - *reinterpret_cast<T *>(itensor2.ptr())) > tolerance) { ++num_mismatches; } |