aboutsummaryrefslogtreecommitdiff
path: root/utils/GraphUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'utils/GraphUtils.cpp')
-rw-r--r--utils/GraphUtils.cpp51
1 files changed, 43 insertions, 8 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp
index 00165cd6c2..3646facab2 100644
--- a/utils/GraphUtils.cpp
+++ b/utils/GraphUtils.cpp
@@ -63,17 +63,33 @@ TFPreproccessor::TFPreproccessor(float min_range, float max_range)
}
void TFPreproccessor::preprocess(ITensor &tensor)
{
+ if(tensor.info()->data_type() == DataType::F32)
+ {
+ preprocess_typed<float>(tensor);
+ }
+ else if(tensor.info()->data_type() == DataType::F16)
+ {
+ preprocess_typed<half>(tensor);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("NOT SUPPORTED!");
+ }
+}
+
+template <typename T>
+void TFPreproccessor::preprocess_typed(ITensor &tensor)
+{
Window window;
window.use_tensor_dimensions(tensor.info()->tensor_shape());
const float range = _max_range - _min_range;
-
execute_window_loop(window, [&](const Coordinates & id)
{
- const float value = *reinterpret_cast<float *>(tensor.ptr_to_element(id));
- float res = value / 255.f; // Normalize to [0, 1]
- res = res * range + _min_range; // Map to [min_range, max_range]
- *reinterpret_cast<float *>(tensor.ptr_to_element(id)) = res;
+ const T value = *reinterpret_cast<T *>(tensor.ptr_to_element(id));
+ float res = value / 255.f; // Normalize to [0, 1]
+ res = res * range + _min_range; // Map to [min_range, max_range]
+ *reinterpret_cast<T *>(tensor.ptr_to_element(id)) = res;
});
}
@@ -88,15 +104,31 @@ CaffePreproccessor::CaffePreproccessor(std::array<float, 3> mean, bool bgr, floa
void CaffePreproccessor::preprocess(ITensor &tensor)
{
+ if(tensor.info()->data_type() == DataType::F32)
+ {
+ preprocess_typed<float>(tensor);
+ }
+ else if(tensor.info()->data_type() == DataType::F16)
+ {
+ preprocess_typed<half>(tensor);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("NOT SUPPORTED!");
+ }
+}
+
+template <typename T>
+void CaffePreproccessor::preprocess_typed(ITensor &tensor)
+{
Window window;
window.use_tensor_dimensions(tensor.info()->tensor_shape());
-
const int channel_idx = get_data_layout_dimension_index(tensor.info()->data_layout(), DataLayoutDimension::CHANNEL);
execute_window_loop(window, [&](const Coordinates & id)
{
- const float value = *reinterpret_cast<float *>(tensor.ptr_to_element(id)) - _mean[id[channel_idx]];
- *reinterpret_cast<float *>(tensor.ptr_to_element(id)) = value * _scale;
+ const T value = *reinterpret_cast<T *>(tensor.ptr_to_element(id)) - T(_mean[id[channel_idx]]);
+ *reinterpret_cast<T *>(tensor.ptr_to_element(id)) = value * T(_scale);
});
}
@@ -370,6 +402,9 @@ bool ValidationOutputAccessor::access_tensor(arm_compute::ITensor &tensor)
case DataType::QASYMM8:
tensor_results = access_predictions_tensor<uint8_t>(tensor);
break;
+ case DataType::F16:
+ tensor_results = access_predictions_tensor<half>(tensor);
+ break;
case DataType::F32:
tensor_results = access_predictions_tensor<float>(tensor);
break;