Skip to content

Commit 29f3a5c

Browse files
authored
Merge pull request matplotlib#29397 from scottshambaugh/masked_array_performance
3D plotting performance improvements
2 parents 743a005 + 4fc3745 commit 29f3a5c

File tree

6 files changed

+187
-92
lines changed

6 files changed

+187
-92
lines changed
+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
3D performance improvements
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
4+
Draw time for 3D plots has been improved, especially for surface and wireframe
5+
plots. Users should see up to a 10x speedup in some cases. This should make
6+
interacting with 3D plots much more responsive.

lib/mpl_toolkits/mplot3d/art3d.py

+150-81
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def get_dir_vector(zdir):
7575

7676
def _viewlim_mask(xs, ys, zs, axes):
7777
"""
78-
Return original points with points outside the axes view limits masked.
78+
Return the mask of the points outside the axes view limits.
7979
8080
Parameters
8181
----------
@@ -86,19 +86,16 @@ def _viewlim_mask(xs, ys, zs, axes):
8686
8787
Returns
8888
-------
89-
xs_masked, ys_masked, zs_masked : np.ma.array
90-
The masked points.
89+
mask : np.array
90+
The mask of the points as a bool array.
9191
"""
9292
mask = np.logical_or.reduce((xs < axes.xy_viewLim.xmin,
9393
xs > axes.xy_viewLim.xmax,
9494
ys < axes.xy_viewLim.ymin,
9595
ys > axes.xy_viewLim.ymax,
9696
zs < axes.zz_viewLim.xmin,
9797
zs > axes.zz_viewLim.xmax))
98-
xs_masked = np.ma.array(xs, mask=mask)
99-
ys_masked = np.ma.array(ys, mask=mask)
100-
zs_masked = np.ma.array(zs, mask=mask)
101-
return xs_masked, ys_masked, zs_masked
98+
return mask
10299

103100

104101
class Text3D(mtext.Text):
@@ -182,14 +179,13 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False):
182179
@artist.allow_rasterization
183180
def draw(self, renderer):
184181
if self._axlim_clip:
185-
xs, ys, zs = _viewlim_mask(self._x, self._y, self._z, self.axes)
186-
position3d = np.ma.row_stack((xs, ys, zs)).ravel().filled(np.nan)
182+
mask = _viewlim_mask(self._x, self._y, self._z, self.axes)
183+
pos3d = np.ma.array([self._x, self._y, self._z],
184+
mask=mask, dtype=float).filled(np.nan)
187185
else:
188-
xs, ys, zs = self._x, self._y, self._z
189-
position3d = np.asanyarray([xs, ys, zs])
186+
pos3d = np.array([self._x, self._y, self._z], dtype=float)
190187

191-
proj = proj3d._proj_trans_points(
192-
[position3d, position3d + self._dir_vec], self.axes.M)
188+
proj = proj3d._proj_trans_points([pos3d, pos3d + self._dir_vec], self.axes.M)
193189
dx = proj[0][1] - proj[0][0]
194190
dy = proj[1][1] - proj[1][0]
195191
angle = math.degrees(math.atan2(dy, dx))
@@ -313,7 +309,12 @@ def get_data_3d(self):
313309
@artist.allow_rasterization
314310
def draw(self, renderer):
315311
if self._axlim_clip:
316-
xs3d, ys3d, zs3d = _viewlim_mask(*self._verts3d, self.axes)
312+
mask = np.broadcast_to(
313+
_viewlim_mask(*self._verts3d, self.axes),
314+
(len(self._verts3d), *self._verts3d[0].shape)
315+
)
316+
xs3d, ys3d, zs3d = np.ma.array(self._verts3d,
317+
dtype=float, mask=mask).filled(np.nan)
317318
else:
318319
xs3d, ys3d, zs3d = self._verts3d
319320
xs, ys, zs, tis = proj3d._proj_transform_clip(xs3d, ys3d, zs3d,
@@ -404,7 +405,8 @@ def do_3d_projection(self):
404405
"""Project the points according to renderer matrix."""
405406
vs_list = [vs for vs, _ in self._3dverts_codes]
406407
if self._axlim_clip:
407-
vs_list = [np.ma.row_stack(_viewlim_mask(*vs.T, self.axes)).T
408+
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
409+
_viewlim_mask(*vs.T, self.axes), vs.shape))
408410
for vs in vs_list]
409411
xyzs_list = [proj3d.proj_transform(*vs.T, self.axes.M) for vs in vs_list]
410412
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
@@ -450,22 +452,32 @@ def do_3d_projection(self):
450452
"""
451453
Project the points according to renderer matrix.
452454
"""
453-
segments = self._segments3d
455+
segments = np.asanyarray(self._segments3d)
456+
457+
mask = False
458+
if np.ma.isMA(segments):
459+
mask = segments.mask
460+
454461
if self._axlim_clip:
455-
all_points = np.ma.vstack(segments)
456-
masked_points = np.ma.column_stack([*_viewlim_mask(*all_points.T,
457-
self.axes)])
458-
segment_lengths = [np.shape(segment)[0] for segment in segments]
459-
segments = np.split(masked_points, np.cumsum(segment_lengths[:-1]))
460-
xyslist = [proj3d._proj_trans_points(points, self.axes.M)
461-
for points in segments]
462-
segments_2d = [np.ma.column_stack([xs, ys]) for xs, ys, zs in xyslist]
462+
viewlim_mask = _viewlim_mask(segments[..., 0],
463+
segments[..., 1],
464+
segments[..., 2],
465+
self.axes)
466+
if np.any(viewlim_mask):
467+
# broadcast mask to 3D
468+
viewlim_mask = np.broadcast_to(viewlim_mask[..., np.newaxis],
469+
(*viewlim_mask.shape, 3))
470+
mask = mask | viewlim_mask
471+
xyzs = np.ma.array(proj3d._proj_transform_vectors(segments, self.axes.M),
472+
mask=mask)
473+
segments_2d = xyzs[..., 0:2]
463474
LineCollection.set_segments(self, segments_2d)
464475

465476
# FIXME
466-
minz = 1e9
467-
for xs, ys, zs in xyslist:
468-
minz = min(minz, min(zs))
477+
if len(xyzs) > 0:
478+
minz = min(xyzs[..., 2].min(), 1e9)
479+
else:
480+
minz = np.nan
469481
return minz
470482

471483

@@ -531,7 +543,9 @@ def get_path(self):
531543
def do_3d_projection(self):
532544
s = self._segment3d
533545
if self._axlim_clip:
534-
xs, ys, zs = _viewlim_mask(*zip(*s), self.axes)
546+
mask = _viewlim_mask(*zip(*s), self.axes)
547+
xs, ys, zs = np.ma.array(zip(*s),
548+
dtype=float, mask=mask).filled(np.nan)
535549
else:
536550
xs, ys, zs = zip(*s)
537551
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
@@ -587,7 +601,9 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False):
587601
def do_3d_projection(self):
588602
s = self._segment3d
589603
if self._axlim_clip:
590-
xs, ys, zs = _viewlim_mask(*zip(*s), self.axes)
604+
mask = _viewlim_mask(*zip(*s), self.axes)
605+
xs, ys, zs = np.ma.array(zip(*s),
606+
dtype=float, mask=mask).filled(np.nan)
591607
else:
592608
xs, ys, zs = zip(*s)
593609
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
@@ -701,14 +717,18 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False):
701717

702718
def do_3d_projection(self):
703719
if self._axlim_clip:
704-
xs, ys, zs = _viewlim_mask(*self._offsets3d, self.axes)
720+
mask = _viewlim_mask(*self._offsets3d, self.axes)
721+
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
705722
else:
706723
xs, ys, zs = self._offsets3d
707724
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
708725
self.axes.M,
709726
self.axes._focal_length)
710727
self._vzs = vzs
711-
super().set_offsets(np.ma.column_stack([vxs, vys]))
728+
if np.ma.isMA(vxs):
729+
super().set_offsets(np.ma.column_stack([vxs, vys]))
730+
else:
731+
super().set_offsets(np.column_stack([vxs, vys]))
712732

713733
if vzs.size > 0:
714734
return min(vzs)
@@ -851,11 +871,18 @@ def set_depthshade(self, depthshade):
851871
self.stale = True
852872

853873
def do_3d_projection(self):
874+
mask = False
875+
for xyz in self._offsets3d:
876+
if np.ma.isMA(xyz):
877+
mask = mask | xyz.mask
854878
if self._axlim_clip:
855-
xs, ys, zs = _viewlim_mask(*self._offsets3d, self.axes)
879+
mask = mask | _viewlim_mask(*self._offsets3d, self.axes)
880+
mask = np.broadcast_to(mask,
881+
(len(self._offsets3d), *self._offsets3d[0].shape))
882+
xyzs = np.ma.array(self._offsets3d, mask=mask)
856883
else:
857-
xs, ys, zs = self._offsets3d
858-
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
884+
xyzs = self._offsets3d
885+
vxs, vys, vzs, vis = proj3d._proj_transform_clip(*xyzs,
859886
self.axes.M,
860887
self.axes._focal_length)
861888
# Sort the points based on z coordinates
@@ -1062,16 +1089,37 @@ def get_vector(self, segments3d):
10621089
return self._get_vector(segments3d)
10631090

10641091
def _get_vector(self, segments3d):
1065-
"""Optimize points for projection."""
1066-
if len(segments3d):
1067-
xs, ys, zs = np.vstack(segments3d).T
1068-
else: # vstack can't stack zero arrays.
1069-
xs, ys, zs = [], [], []
1070-
ones = np.ones(len(xs))
1071-
self._vec = np.array([xs, ys, zs, ones])
1092+
"""
1093+
Optimize points for projection.
10721094
1073-
indices = [0, *np.cumsum([len(segment) for segment in segments3d])]
1074-
self._segslices = [*map(slice, indices[:-1], indices[1:])]
1095+
Parameters
1096+
----------
1097+
segments3d : NumPy array or list of NumPy arrays
1098+
List of vertices of the boundary of every segment. If all paths are
1099+
of equal length and this argument is a NumPy array, then it should
1100+
be of shape (num_faces, num_vertices, 3).
1101+
"""
1102+
if isinstance(segments3d, np.ndarray):
1103+
if segments3d.ndim != 3 or segments3d.shape[-1] != 3:
1104+
raise ValueError("segments3d must be a MxNx3 array, but got "
1105+
f"shape {segments3d.shape}")
1106+
if isinstance(segments3d, np.ma.MaskedArray):
1107+
self._faces = segments3d.data
1108+
self._invalid_vertices = segments3d.mask.any(axis=-1)
1109+
else:
1110+
self._faces = segments3d
1111+
self._invalid_vertices = False
1112+
else:
1113+
# Turn the potentially ragged list into a numpy array for later speedups
1114+
# If it is ragged, set the unused vertices per face as invalid
1115+
num_faces = len(segments3d)
1116+
num_verts = np.fromiter(map(len, segments3d), dtype=np.intp)
1117+
max_verts = num_verts.max(initial=0)
1118+
segments = np.empty((num_faces, max_verts, 3))
1119+
for i, face in enumerate(segments3d):
1120+
segments[i, :len(face)] = face
1121+
self._faces = segments
1122+
self._invalid_vertices = np.arange(max_verts) >= num_verts[:, None]
10751123

10761124
def set_verts(self, verts, closed=True):
10771125
"""
@@ -1133,64 +1181,85 @@ def do_3d_projection(self):
11331181
self._facecolor3d = self._facecolors
11341182
if self._edge_is_mapped:
11351183
self._edgecolor3d = self._edgecolors
1184+
1185+
needs_masking = np.any(self._invalid_vertices)
1186+
num_faces = len(self._faces)
1187+
mask = self._invalid_vertices
1188+
1189+
# Some faces might contain masked vertices, so we want to ignore any
1190+
# errors that those might cause
1191+
with np.errstate(invalid='ignore', divide='ignore'):
1192+
pfaces = proj3d._proj_transform_vectors(self._faces, self.axes.M)
1193+
11361194
if self._axlim_clip:
1137-
xs, ys, zs = _viewlim_mask(*self._vec[0:3], self.axes)
1138-
if self._vec.shape[0] == 4: # Will be 3 (xyz) or 4 (xyzw)
1139-
w_masked = np.ma.masked_where(zs.mask, self._vec[3])
1140-
vec = np.ma.array([xs, ys, zs, w_masked])
1141-
else:
1142-
vec = np.ma.array([xs, ys, zs])
1143-
else:
1144-
vec = self._vec
1145-
txs, tys, tzs = proj3d._proj_transform_vec(vec, self.axes.M)
1146-
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
1195+
viewlim_mask = _viewlim_mask(self._faces[..., 0], self._faces[..., 1],
1196+
self._faces[..., 2], self.axes)
1197+
if np.any(viewlim_mask):
1198+
needs_masking = True
1199+
mask = mask | viewlim_mask
1200+
1201+
pzs = pfaces[..., 2]
1202+
if needs_masking:
1203+
pzs = np.ma.MaskedArray(pzs, mask=mask)
11471204

11481205
# This extra fuss is to re-order face / edge colors
11491206
cface = self._facecolor3d
11501207
cedge = self._edgecolor3d
1151-
if len(cface) != len(xyzlist):
1152-
cface = cface.repeat(len(xyzlist), axis=0)
1153-
if len(cedge) != len(xyzlist):
1208+
if len(cface) != num_faces:
1209+
cface = cface.repeat(num_faces, axis=0)
1210+
if len(cedge) != num_faces:
11541211
if len(cedge) == 0:
11551212
cedge = cface
11561213
else:
1157-
cedge = cedge.repeat(len(xyzlist), axis=0)
1158-
1159-
if xyzlist:
1160-
# sort by depth (furthest drawn first)
1161-
z_segments_2d = sorted(
1162-
((self._zsortfunc(zs.data), np.ma.column_stack([xs, ys]), fc, ec, idx)
1163-
for idx, ((xs, ys, zs), fc, ec)
1164-
in enumerate(zip(xyzlist, cface, cedge))),
1165-
key=lambda x: x[0], reverse=True)
1166-
1167-
_, segments_2d, self._facecolors2d, self._edgecolors2d, idxs = \
1168-
zip(*z_segments_2d)
1169-
else:
1170-
segments_2d = []
1171-
self._facecolors2d = np.empty((0, 4))
1172-
self._edgecolors2d = np.empty((0, 4))
1173-
idxs = []
1174-
1175-
if self._codes3d is not None:
1176-
codes = [self._codes3d[idx] for idx in idxs]
1177-
PolyCollection.set_verts_and_codes(self, segments_2d, codes)
1214+
cedge = cedge.repeat(num_faces, axis=0)
1215+
1216+
if len(pzs) > 0:
1217+
face_z = self._zsortfunc(pzs, axis=-1)
11781218
else:
1179-
PolyCollection.set_verts(self, segments_2d, self._closed)
1219+
face_z = pzs
1220+
if needs_masking:
1221+
face_z = face_z.data
1222+
face_order = np.argsort(face_z, axis=-1)[::-1]
11801223

1181-
if len(self._edgecolor3d) != len(cface):
1224+
if len(pfaces) > 0:
1225+
faces_2d = pfaces[face_order, :, :2]
1226+
else:
1227+
faces_2d = pfaces
1228+
if self._codes3d is not None and len(self._codes3d) > 0:
1229+
if needs_masking:
1230+
segment_mask = ~mask[face_order, :]
1231+
faces_2d = [face[mask, :] for face, mask
1232+
in zip(faces_2d, segment_mask)]
1233+
codes = [self._codes3d[idx] for idx in face_order]
1234+
PolyCollection.set_verts_and_codes(self, faces_2d, codes)
1235+
else:
1236+
if needs_masking and len(faces_2d) > 0:
1237+
invalid_vertices_2d = np.broadcast_to(
1238+
mask[face_order, :, None],
1239+
faces_2d.shape)
1240+
faces_2d = np.ma.MaskedArray(
1241+
faces_2d, mask=invalid_vertices_2d)
1242+
PolyCollection.set_verts(self, faces_2d, self._closed)
1243+
1244+
if len(cface) > 0:
1245+
self._facecolors2d = cface[face_order]
1246+
else:
1247+
self._facecolors2d = cface
1248+
if len(self._edgecolor3d) == len(cface) and len(cedge) > 0:
1249+
self._edgecolors2d = cedge[face_order]
1250+
else:
11821251
self._edgecolors2d = self._edgecolor3d
11831252

11841253
# Return zorder value
11851254
if self._sort_zpos is not None:
11861255
zvec = np.array([[0], [0], [self._sort_zpos], [1]])
11871256
ztrans = proj3d._proj_transform_vec(zvec, self.axes.M)
11881257
return ztrans[2][0]
1189-
elif tzs.size > 0:
1258+
elif pzs.size > 0:
11901259
# FIXME: Some results still don't look quite right.
11911260
# In particular, examine contourf3d_demo2.py
11921261
# with az = -54 and elev = -45.
1193-
return np.min(tzs)
1262+
return np.min(pzs)
11941263
else:
11951264
return np.nan
11961265

lib/mpl_toolkits/mplot3d/axes3d.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2892,7 +2892,9 @@ def add_collection3d(self, col, zs=0, zdir='z', autolim=True, *,
28922892
self.auto_scale_xyz(*np.array(col._segments3d).transpose(),
28932893
had_data=had_data)
28942894
elif isinstance(col, art3d.Poly3DCollection):
2895-
self.auto_scale_xyz(*col._vec[:-1], had_data=had_data)
2895+
self.auto_scale_xyz(col._faces[..., 0],
2896+
col._faces[..., 1],
2897+
col._faces[..., 2], had_data=had_data)
28962898
elif isinstance(col, art3d.Patch3DCollection):
28972899
pass
28982900
# FIXME: Implement auto-scaling function for Patch3DCollection

0 commit comments

Comments
 (0)