aboutsummaryrefslogtreecommitdiff
path: root/utils/GraphUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'utils/GraphUtils.cpp')
-rw-r--r--utils/GraphUtils.cpp39
1 files changed, 39 insertions, 0 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp
index 145e44950b..0edb6f2a56 100644
--- a/utils/GraphUtils.cpp
+++ b/utils/GraphUtils.cpp
@@ -129,6 +129,45 @@ bool DummyAccessor::access_tensor(ITensor &tensor)
return ret;
}
+NumPyAccessor::NumPyAccessor(std::string npy_path, TensorShape shape, DataType data_type, std::ostream &output_stream)
+ : _npy_tensor(), _filename(std::move(npy_path)), _output_stream(output_stream)
+{
+ NumPyBinLoader loader(_filename);
+
+ TensorInfo info(shape, 1, data_type);
+ _npy_tensor.allocator()->init(info);
+ _npy_tensor.allocator()->allocate();
+
+ loader.access_tensor(_npy_tensor);
+}
+
+template <typename T>
+void NumPyAccessor::access_numpy_tensor(ITensor &tensor)
+{
+ const int num_elements = tensor.info()->total_size();
+ int num_mismatches = utils::compare_tensor<T>(tensor, _npy_tensor);
+ float percentage_mismatches = static_cast<float>(num_mismatches) / num_elements;
+
+ _output_stream << "Results: " << 100.f - (percentage_mismatches * 100) << " % matches with the provided output[" << _filename << "]." << std::endl;
+}
+
+bool NumPyAccessor::access_tensor(ITensor &tensor)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::F32);
+ ARM_COMPUTE_ERROR_ON(_npy_tensor.info()->dimension(0) != tensor.info()->dimension(0));
+
+ switch(tensor.info()->data_type())
+ {
+ case DataType::F32:
+ access_numpy_tensor<float>(tensor);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("NOT SUPPORTED!");
+ }
+
+ return false;
+}
+
PPMAccessor::PPMAccessor(std::string ppm_path, bool bgr, std::unique_ptr<IPreprocessor> preprocessor)
: _ppm_path(std::move(ppm_path)), _bgr(bgr), _preprocessor(std::move(preprocessor))
{