
Commit

fix some part of fdb backend for cyclic axes
mathleur committed Jun 25, 2024
1 parent d4b8211 commit 73facc3
Showing 3 changed files with 60 additions and 15 deletions.
3 changes: 3 additions & 0 deletions examples/3D_shipping_route.py
@@ -15,6 +15,9 @@ class Test:
def setup_method(self):
ds = data.from_source("file", "./examples/data/winds.grib")
array = ds.to_xarray()
print(array.number)
print(array.surface)
print(array.time)
array = array.isel(time=0).isel(surface=0).isel(number=0).u10
self.array = array
self.slicer = HullSlicer()
70 changes: 56 additions & 14 deletions polytope/datacube/backends/fdb.py
@@ -87,8 +87,8 @@ def get(self, requests: TensorIndexTree):
total_request_decoding_info = []
total_uncompressed_requests = []
time0 = time.time()
print("THE COMPRESSED REQUESTS HERE")
print(fdb_requests)
# print("THE COMPRESSED REQUESTS HERE")
# print(fdb_requests)
for j, compressed_request in enumerate(fdb_requests):
uncompressed_request = {}

@@ -117,6 +117,7 @@ def get(self, requests: TensorIndexTree):
# print(total_uncompressed_requests)
print("GJ TIME")
print(time.time() - time1)
# print(output_values)
time2 = time.time()
self.assign_fdb_output_to_nodes(output_values, total_request_decoding_info)
print("ASSIGN GJ DATA TO RIGHT NODES")
@@ -164,10 +165,10 @@ def get_fdb_requests(
)
self.bunching_up_request_time += time.time() - time2
time1 = time.time()
print("AND NOW LOOK NOW")
print(current_start_idxs)
(original_indices, sorted_request_ranges) = self.sort_fdb_request_ranges(
range_lengths, current_start_idxs, lat_length
# print("AND NOW LOOK NOW")
# print(current_start_idxs)
(original_indices, sorted_request_ranges, fdb_node_ranges, current_start_idxs) = self.sort_fdb_request_ranges(
range_lengths, current_start_idxs, lat_length, fdb_node_ranges
)
self.request_sorting_time += time.time() - time1
time3 = time.time()
@@ -338,47 +339,88 @@ def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info):
# print(time.time() - time1)
# self.sorting_time +=time.time() - time1
# time1 = time.time()
# print("LOOK NOW REALLY")
# print(fdb_node_ranges)
# print(current_start_idxs)
for i in range(len(sorted_fdb_range_nodes)):
# for k in range(sorted_range_lengths[i]):
n = sorted_fdb_range_nodes[i]
interm_request_output_values = request_output_values[0][i][0]
# interm_request_output_values = request_output_values[0][i][0]
# TODO: k again??
for j in range(len(sorted_current_start_idxs[i])):
# n = sorted_fdb_range_nodes[i]
m = n[j][0]
time1 = time.time()
# time1 = time.time()
m.result.append(request_output_values[0][i][0][j])
self.sorting_time += time.time() - time1
# self.sorting_time += time.time() - time1

def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length):
def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length, fdb_node_ranges):
# print("WHAT IS THE CURREENT START IDX")
# print(current_start_idx)
interm_request_ranges = []
# print("WHAT ARE THE NODE RANGES?")
# print(fdb_node_ranges)
# TODO: modify the start indexes to have as many arrays as the request ranges
new_fdb_node_ranges = []
new_current_start_idx = []
for i in range(lat_length):
interm_fdb_nodes = fdb_node_ranges[i]
interm_start_idx = current_start_idx[i]
# print("WHAT IS THE CURREENT START IDX")
# print(interm_start_idx)
# for j in range(len(range_lengths[i])):
if True:
# if current_start_idx[i][0] is not None:
if True:
print(current_start_idx[i][-1]+1 - current_start_idx[i][0])
print(len(current_start_idx[i]))
if current_start_idx[i][-1]+1 - current_start_idx[i][0] <= len(current_start_idx[i]):
# print(current_start_idx[i][-1]+1 - current_start_idx[i][0])
# print(len(current_start_idx[i]))
if abs(current_start_idx[i][-1]+1 - current_start_idx[i][0]) <= len(current_start_idx[i]):
# print("WE DID NOT DIVIDE THE IDX RANGES")
current_request_ranges = (current_start_idx[i][0], current_start_idx[i][-1]+1)
# print(current_request_ranges)
interm_request_ranges.append(current_request_ranges)
new_fdb_node_ranges.append(interm_fdb_nodes)
new_current_start_idx.append(interm_start_idx)
else:
time0 = time.time()
# TODO: see where we have jump in indices and separate the ranges there
jumps = list(map(operator.sub, current_start_idx[i][1:], current_start_idx[i][:-1]))
last_idx = 0
for j, jump in enumerate(jumps):
# new_interm_fdb_nodes = []
# new_interm_start_idx = []
if jump > 1:
current_request_ranges = (current_start_idx[i][last_idx], current_start_idx[i][j]+1)
# new_interm_fdb_nodes.append(interm_fdb_nodes[last_idx:j + 1])
# new_interm_start_idx.append(interm_start_idx[last_idx:j + 1])
new_fdb_node_ranges.append(interm_fdb_nodes[last_idx:j + 1])
new_current_start_idx.append(interm_start_idx[last_idx:j + 1])
last_idx = j+1
interm_request_ranges.append(current_request_ranges)
# print("DID WE NOT ADD HERE?")
# print(new_interm_start_idx)
# print(interm_fdb_nodes)
# print(last_idx)
# print(j)
# new_interm_fdb_nodes.append(interm_fdb_nodes[last_idx:j])
# new_interm_start_idx.append(interm_start_idx[last_idx:j])
if j == len(current_start_idx[i]) - 2:
current_request_ranges = (current_start_idx[i][last_idx], current_start_idx[i][-1]+1)
interm_request_ranges.append(current_request_ranges)
# new_interm_fdb_nodes.append(interm_fdb_nodes[last_idx:])
# new_interm_start_idx.append(interm_start_idx[last_idx:])
new_fdb_node_ranges.append(interm_fdb_nodes[last_idx:])
new_current_start_idx.append(interm_start_idx[last_idx:])
print("TIME FOR CONSTRUCTING THE JUMP RANGES")
print(time.time() - time0)
request_ranges_with_idx = list(enumerate(interm_request_ranges))
sorted_list = sorted(request_ranges_with_idx, key=lambda x: x[1][0])
original_indices, sorted_request_ranges = zip(*sorted_list)
return (original_indices, sorted_request_ranges)
# print("INSIDE THE SORTING PROBLEM?")
# print(sorted_request_ranges)
# print(new_current_start_idx)
# print(new_fdb_node_ranges)
return (original_indices, sorted_request_ranges, new_fdb_node_ranges, new_current_start_idx)

def datacube_natural_indexes(self, axis, subarray):
indexes = subarray.get(axis.name, None)
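The core of this commit is the change to sort_fdb_request_ranges: instead of always turning each latitude line's start indexes into a single contiguous range, it now splits the indexes wherever there is a jump (which happens when a cyclic axis wraps around), and it carries the matching fdb node ranges and start indexes along so they stay aligned with the split ranges before sorting. Below is a minimal, standalone Python sketch of that idea; the function name, the input shapes, and the example values are simplified assumptions for illustration, not the actual polytope API.

import operator


def split_and_sort_ranges(start_idxs, node_ranges):
    """Split per-line start indexes into contiguous (start, end) ranges,
    keeping the matching node ranges aligned, then sort by range start."""
    request_ranges = []
    new_node_ranges = []
    new_start_idxs = []

    for idxs, nodes in zip(start_idxs, node_ranges):
        # If the indexes are already contiguous, keep them as one range.
        if abs(idxs[-1] + 1 - idxs[0]) <= len(idxs):
            request_ranges.append((idxs[0], idxs[-1] + 1))
            new_node_ranges.append(nodes)
            new_start_idxs.append(idxs)
        else:
            # Otherwise, cut a new range at every jump in the indexes
            # (this is what happens when a cyclic axis wraps around).
            jumps = list(map(operator.sub, idxs[1:], idxs[:-1]))
            last = 0
            for j, jump in enumerate(jumps):
                if jump > 1:
                    request_ranges.append((idxs[last], idxs[j] + 1))
                    new_node_ranges.append(nodes[last:j + 1])
                    new_start_idxs.append(idxs[last:j + 1])
                    last = j + 1
                if j == len(idxs) - 2:
                    request_ranges.append((idxs[last], idxs[-1] + 1))
                    new_node_ranges.append(nodes[last:])
                    new_start_idxs.append(idxs[last:])

    # Sort the ranges by their start index, remembering the original order
    # so the decoded output can later be mapped back to the right nodes.
    indexed = list(enumerate(request_ranges))
    indexed.sort(key=lambda x: x[1][0])
    original_indices, sorted_ranges = zip(*indexed)
    return original_indices, sorted_ranges, new_node_ranges, new_start_idxs


# Example: indexes 0-2 and 358-359 belong to the same latitude line but sit
# on either side of a cyclic longitude wrap, so they become two ranges.
print(split_and_sort_ranges([[0, 1, 2, 358, 359]],
                            [["n0", "n1", "n2", "n358", "n359"]]))

With these hypothetical inputs the sketch returns the ranges (0, 3) and (358, 360) together with the node and index slices that belong to each range, which is what lets assign_fdb_output_to_nodes write the decoded values back to the correct tree nodes.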
2 changes: 1 addition & 1 deletion polytope/engine/hullslicer.py
@@ -192,7 +192,7 @@ def find_compressed_axes(self, datacube, polytopes):
if compressed_axis in datacube.compressed_axes:
self.compressed_axes.append(compressed_axis)
# add the last axis of the grid always (longitude) as a compressed axis
# self.compressed_axes.append(datacube.coupled_axes[0][-1])
self.compressed_axes.append(datacube.coupled_axes[0][-1])

def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]):
# Determine list of axes to compress
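The hullslicer change re-enables unconditionally appending the last axis of the grid (longitude, the last entry of the first coupled-axes pair) to the list of compressed axes. A small illustrative sketch of that bookkeeping is below; the FakeDatacube stand-in and the axis names are assumptions for illustration only, not the real Datacube class.

class FakeDatacube:
    compressed_axes = ["number", "longitude"]
    coupled_axes = [["latitude", "longitude"]]  # grid axes, longitude last


def find_compressed_axes(datacube, candidate_axes):
    # Keep only the candidates the datacube itself marks as compressible.
    compressed = [ax for ax in candidate_axes if ax in datacube.compressed_axes]
    # Always add the last axis of the grid (longitude) as a compressed axis,
    # as re-enabled in this commit.
    compressed.append(datacube.coupled_axes[0][-1])
    return compressed


print(find_compressed_axes(FakeDatacube(), ["number", "step"]))  # ['number', 'longitude']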
