diff --git a/numba/compiler.py b/numba/compiler.py index 6370da14d22..f5bdaa263c4 100644 --- a/numba/compiler.py +++ b/numba/compiler.py @@ -441,9 +441,9 @@ def define_nopython_pipeline(state, name='nopython'): # pre typing if not state.flags.no_rewrites: - pm.add_pass(GenericRewrites, "nopython rewrites") pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants") pm.add_pass(DeadBranchPrune, "dead branch pruning") + pm.add_pass(GenericRewrites, "nopython rewrites") pm.add_pass(InlineClosureLikes, "inline calls to locally defined closures") diff --git a/numba/tests/test_analysis.py b/numba/tests/test_analysis.py index 6be83321bd4..965a1aa6395 100644 --- a/numba/tests/test_analysis.py +++ b/numba/tests/test_analysis.py @@ -633,3 +633,14 @@ def impl(fa): FakeArrayType = types.NamedUniTuple(types.int64, 1, FakeArray) self.assert_prune(impl, (FakeArrayType,), [None], fa, flags=enable_pyobj_flags) + + def test_semantic_const_propagates_before_static_rewrites(self): + # see issue #5015, the ndim needs writing in as a const before + # the rewrite passes run to make e.g. getitems static where possible + @njit + def impl(a, b): + return a.shape[:b.ndim] + + args = (np.zeros((5, 4, 3, 2)), np.zeros((1, 1))) + + self.assertPreciseEqual(impl(*args), impl.py_func(*args))