Skip to content

Commit 1b58b40

Browse files
author
Pedro Silva
committed
1 parent 878f195 commit 1b58b40

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

verde/tests/test_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,22 @@
2727
meshgrid_to_1d,
2828
parse_engine,
2929
partition_by_sum,
30+
fill_nans
3031
)
3132

3233

34+
def test_fill_nans():
    "Check that NaNs are filled and that valid values are left untouched"
    grid = np.array(
        [
            [1, np.nan, 3],
            [4, 5, np.nan],
            [np.nan, 7, 8],
        ]
    )
    # Keep a pristine copy to compare against, in case the function works
    # in place on the array that is passed in.
    original = grid.copy()
    valid = ~np.isnan(original)
    filled_grid = fill_nans(grid)
    # The output must have the same shape and no missing values left
    assert filled_grid.shape == original.shape
    assert not np.isnan(filled_grid).any()
    # Cells that already had data must keep their values
    np.testing.assert_allclose(filled_grid[valid], original[valid])
44+
45+
3346
def test_parse_engine():
3447
"Check that it works for common input"
3548
assert parse_engine("numba") == "numba"

verde/utils.py

+31
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pandas as pd
1515
import xarray as xr
1616
from scipy.spatial import cKDTree
17+
from sklearn.impute import KNNImputer
1718

1819
try:
1920
from pykdtree.kdtree import KDTree as pyKDTree
@@ -681,6 +682,36 @@ def kdtree(coordinates, use_pykdtree=True, **kwargs):
681682
return tree
682683

683684

685+
def fill_nans(grid, n_neighbors=1):
    """
    Fill the NaN values of a grid using its nearest valid neighbors.

    Each NaN cell is replaced by the mean of the *n_neighbors* closest
    non-NaN cells. Distances are Euclidean and measured in grid-index space
    (rows and columns), not in coordinate values. The input is never
    modified; a filled copy is returned.

    Parameters
    ----------
    grid : array-like, :class:`xarray.DataArray` or :class:`xarray.Dataset`
        The grid to fill. For a Dataset, every data variable is filled
        independently.
    n_neighbors : int
        Number of nearest valid cells averaged to fill each NaN. Must be at
        least 1. If it exceeds the number of valid cells, all valid cells
        are used. Larger values take longer on large grids.

    Returns
    -------
    grid : :class:`numpy.ndarray`, :class:`xarray.DataArray` or \
:class:`xarray.Dataset`
        A copy of the input with the NaN values filled. Plain array-like
        input yields a float numpy array; xarray input yields the same kind
        of xarray object.

    Raises
    ------
    ValueError
        If *n_neighbors* is smaller than 1 or if the grid contains no valid
        (non-NaN) values to fill from.
    """
    if n_neighbors < 1:
        raise ValueError(
            "n_neighbors must be at least 1 (got {}).".format(n_neighbors)
        )
    # Duck-typed Dataset support: apply the fill to each data variable.
    if hasattr(grid, "data_vars"):
        return grid.map(fill_nans, n_neighbors=n_neighbors)

    # np.array always copies, so the caller's grid is left untouched.
    values = np.array(grid, dtype=float)
    missing = np.isnan(values)
    if missing.any():
        valid = ~missing
        if not valid.any():
            raise ValueError(
                "Cannot fill NaNs in a grid that has no valid values."
            )
        # Boolean indexing and argwhere both walk the array in C order, so
        # valid_values[i] is the value at index np.argwhere(valid)[i].
        valid_values = values[valid]
        # Asking for more neighbors than exist makes the tree return
        # out-of-range indices, so cap k at the number of valid cells.
        k = min(n_neighbors, valid_values.size)
        tree = cKDTree(np.argwhere(valid))
        _, nearest = tree.query(np.argwhere(missing), k=k)
        # query() drops the neighbor axis when k == 1; restore it so the
        # averaging below is uniform.
        nearest = np.reshape(nearest, (-1, k))
        values[missing] = valid_values[nearest].mean(axis=1)

    # Duck-typed DataArray support: keep dims, coords and attrs.
    if hasattr(grid, "dims") and hasattr(grid, "copy"):
        return grid.copy(data=values)
    return values
714+
684715
def partition_by_sum(array, parts):
685716
"""
686717
Partition an array into parts of approximately equal sum.

0 commit comments

Comments
 (0)