Skip to content

Commit 083ca6c

Browse files
committedMar 1, 2025
feat: allow merging multiple simpleapi indexes
Before this change users would have no way to download pytorch from multiple indexes (e.g. `cpu` version on some platforms and `gpu` on another) and use it inside a single hub repository. This brings users the necessary toggles to tell `rules_python` to search in multiple indexes for the `torch` wheels. Note that the `index_strategy` field is going to be used only if one is setting multiple indexes via the `extra_index_urls` and not via the `index_url_overrides`. While at it, I have improved the `simpleapi_download` tests to also test the warning messages that we may print to the user. Fixes bazel-contrib#2622
1 parent bb6249b commit 083ca6c

File tree

3 files changed

+326
-44
lines changed

3 files changed

+326
-44
lines changed
 

‎python/private/pypi/extension.bzl

+31-1
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ You cannot use both the additive_build_content and additive_build_content_file a
472472
index_url = pip_attr.experimental_index_url,
473473
extra_index_urls = pip_attr.experimental_extra_index_urls or [],
474474
index_url_overrides = pip_attr.experimental_index_url_overrides or {},
475+
index_strategy = pip_attr.index_strategy,
475476
sources = distributions,
476477
envsubst = pip_attr.envsubst,
477478
# Auth related info
@@ -681,27 +682,41 @@ stable.
681682
682683
This is equivalent to `--index-url` `pip` option.
683684
685+
:::{warning}
686+
`rules_python` will fall back to using `pip` to download wheels if the requirements
687+
files do not have hashes.
688+
:::
689+
684690
:::{versionchanged} 0.37.0
685691
If {attr}`download_only` is set, then `sdist` archives will be discarded and `pip.parse` will
686692
operate in wheel-only mode.
687693
:::
688694
""",
689695
),
690696
"experimental_index_url_overrides": attr.string_dict(
697+
# TODO @aignas 2025-03-01: consider using string_list_dict so that
698+
# we could have index_url_overrides per package for different
699+
# platforms like what `uv` has.
700+
# See https://docs.astral.sh/uv/configuration/indexes/#-index-url-and-extra-index-url
691701
doc = """\
692702
The index URL overrides for each package to use for downloading wheels using
693703
bazel downloader. This value is going to be subject to `envsubst` substitutions
694704
if necessary.
695705
696706
The key is the package name (will be normalized before usage) and the value is the
697-
index URL.
707+
index URLs separated with `,`.
698708
699709
This design pattern has been chosen in order to be fully deterministic about which
700710
packages come from which source. We want to avoid issues similar to what happened in
701711
https://pytorch.org/blog/compromised-nightly-dependency/.
702712
703713
The indexes must support Simple API as described here:
704714
https://packaging.python.org/en/latest/specifications/simple-repository-api/
715+
716+
:::{versionchanged} VERSION_NEXT_PATCH
717+
This can contain comma-separated values per package to allow `torch` to be
718+
indexed from multiple sources.
719+
:::
705720
""",
706721
),
707722
"hub_name": attr.string(
@@ -724,6 +739,21 @@ is not required. Each hub is a separate resolution of pip dependencies. This
724739
means if different programs need different versions of some library, separate
725740
hubs can be created, and each program can use its respective hub's targets.
726741
Targets from different hubs should not be used together.
742+
""",
743+
),
744+
"index_strategy": attr.string(
745+
default = "first-index",
746+
values = ["first-index", "unsafe"],
747+
doc = """\
748+
The strategy used when fetching package locations from indexes. This is to allow fetching
749+
`torch` from both the `torch`-maintained index and PyPI so that on different platforms users
750+
can have different torch versions (e.g. gpu accelerated on linux and cpu on the
751+
rest of the platforms).
752+
753+
See https://docs.astral.sh/uv/configuration/indexes/#searching-across-multiple-indexes.
754+
755+
:::{versionadded} VERSION_NEXT_PATCH
756+
:::
727757
""",
728758
),
729759
"parallel_download": attr.bool(

‎python/private/pypi/simpleapi_download.bzl

+82-35
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def simpleapi_download(
3131
parallel_download = True,
3232
read_simpleapi = None,
3333
get_auth = None,
34+
_print = print,
3435
_fail = fail):
3536
"""Download Simple API HTML.
3637
@@ -43,6 +44,8 @@ def simpleapi_download(
4344
separate packages.
4445
* extra_index_urls: Extra index URLs that will be looked up after
4546
the main is looked up.
47+
* index_strategy: The string identifier representing the strategy
48+
used here. Can be either "first-index" or "unsafe".
4649
* sources: list[str], the sources to download things for. Each value is
4750
the contents of requirements files.
4851
* envsubst: list[str], the envsubst vars for performing substitution in index url.
@@ -61,6 +64,7 @@ def simpleapi_download(
6164
read_simpleapi: a function for reading and parsing of the SimpleAPI contents.
6265
Used in tests.
6366
get_auth: A function to get auth information passed to read_simpleapi. Used in tests.
67+
_print: a function to print. Used in tests.
6468
_fail: a function to print a failure. Used in tests.
6569
6670
Returns:
@@ -71,6 +75,9 @@ def simpleapi_download(
7175
for p, i in (attr.index_url_overrides or {}).items()
7276
}
7377

78+
if attr.index_strategy not in ["unsafe", "first-index"]:
79+
fail("TODO")
80+
7481
download_kwargs = {}
7582
if bazel_features.external_deps.download_has_block_param:
7683
download_kwargs["block"] = not parallel_download
@@ -80,68 +87,108 @@ def simpleapi_download(
8087
contents = {}
8188
index_urls = [attr.index_url] + attr.extra_index_urls
8289
read_simpleapi = read_simpleapi or _read_simpleapi
90+
sources = {
91+
pkg: normalize_name(pkg)
92+
for pkg in attr.sources
93+
}
8394

84-
found_on_index = {}
95+
found_on_indexes = {}
8596
warn_overrides = False
8697
for i, index_url in enumerate(index_urls):
8798
if i != 0:
8899
# Warn the user about a potential fix for the overrides
89100
warn_overrides = True
90101

91102
async_downloads = {}
92-
sources = [pkg for pkg in attr.sources if pkg not in found_on_index]
93-
for pkg in sources:
103+
for pkg, pkg_normalized in sources.items():
104+
if pkg not in found_on_indexes:
105+
# We have not found the pkg yet, let's search for it
106+
pass
107+
elif "first-index" == attr.index_strategy and pkg in found_on_indexes:
108+
# We have found it and we are using a safe strategy, let's not
109+
# search anymore.
110+
continue
111+
elif pkg in found_on_indexes and pkg_normalized in index_url_overrides:
112+
# This pkg has been overridden, be strict and use `first-index` strategy
113+
# implicitly.
114+
continue
115+
elif "unsafe" in attr.index_strategy:
116+
# We can search for the packages
117+
pass
118+
else:
119+
fail("BUG: Unknown state of searching of packages")
120+
94121
pkg_normalized = normalize_name(pkg)
95-
result = read_simpleapi(
96-
ctx = ctx,
97-
url = "{}/{}/".format(
98-
index_url_overrides.get(pkg_normalized, index_url).rstrip("/"),
99-
pkg,
100-
),
101-
attr = attr,
102-
cache = cache,
103-
get_auth = get_auth,
104-
**download_kwargs
105-
)
106-
if hasattr(result, "wait"):
107-
# We will process it in a separate loop:
108-
async_downloads[pkg] = struct(
109-
pkg_normalized = pkg_normalized,
110-
wait = result.wait,
122+
override_urls = index_url_overrides.get(pkg_normalized, index_url)
123+
for url in override_urls.split(","):
124+
result = read_simpleapi(
125+
ctx = ctx,
126+
url = "{}/{}/".format(
127+
url.rstrip("/"),
128+
pkg,
129+
),
130+
attr = attr,
131+
cache = cache,
132+
get_auth = get_auth,
133+
**download_kwargs
111134
)
112-
elif result.success:
113-
contents[pkg_normalized] = result.output
114-
found_on_index[pkg] = index_url
135+
if hasattr(result, "wait"):
136+
# We will process it in a separate loop:
137+
async_downloads.setdefault(pkg, []).append(
138+
struct(
139+
pkg_normalized = pkg_normalized,
140+
wait = result.wait,
141+
),
142+
)
143+
elif result.success:
144+
current = contents.get(
145+
pkg_normalized,
146+
struct(sdists = {}, whls = {}),
147+
)
148+
contents[pkg_normalized] = struct(
149+
# Always prefer the current values, so that the first index wins
150+
sdists = result.output.sdists | current.sdists,
151+
whls = result.output.whls | current.whls,
152+
)
153+
found_on_indexes.setdefault(pkg, []).append(url)
115154

116155
if not async_downloads:
117156
continue
118157

119158
# If we use `block` == False, then we need to have a second loop that is
120159
# collecting all of the results as they were being downloaded in parallel.
121-
for pkg, download in async_downloads.items():
122-
result = download.wait()
123-
124-
if result.success:
125-
contents[download.pkg_normalized] = result.output
126-
found_on_index[pkg] = index_url
127-
128-
failed_sources = [pkg for pkg in attr.sources if pkg not in found_on_index]
160+
for pkg, downloads in async_downloads.items():
161+
for download in downloads:
162+
result = download.wait()
163+
164+
if result.success:
165+
current = contents.get(
166+
download.pkg_normalized,
167+
struct(sdists = {}, whls = {}),
168+
)
169+
contents[download.pkg_normalized] = struct(
170+
# Always prefer the current values, so that the first index wins
171+
sdists = result.output.sdists | current.sdists,
172+
whls = result.output.whls | current.whls,
173+
)
174+
found_on_indexes.setdefault(pkg, []).append(index_url)
175+
176+
failed_sources = [pkg for pkg in attr.sources if pkg not in found_on_indexes]
129177
if failed_sources:
130-
_fail("Failed to download metadata for {} for from urls: {}".format(
178+
_fail("Failed to download metadata for {} from urls: {}".format(
131179
failed_sources,
132180
index_urls,
133181
))
134182
return None
135183

136184
if warn_overrides:
137185
index_url_overrides = {
138-
pkg: found_on_index[pkg]
186+
pkg: ",".join(found_on_indexes[pkg])
139187
for pkg in attr.sources
140-
if found_on_index[pkg] != attr.index_url
188+
if found_on_indexes[pkg] != attr.index_url
141189
}
142190

143-
# buildifier: disable=print
144-
print("You can use the following `index_url_overrides` to avoid the 404 warnings:\n{}".format(
191+
_print("You can use the following `index_url_overrides` to avoid the 404 warnings:\n{}".format(
145192
render.dict(index_url_overrides),
146193
))
147194

‎tests/pypi/simpleapi_download/simpleapi_download_tests.bzl

+213-8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ _tests = []
2121

2222
def _test_simple(env):
2323
calls = []
24+
warnings_suggestion = []
2425

2526
def read_simpleapi(ctx, url, attr, cache, get_auth, block):
2627
_ = ctx # buildifier: disable=unused-variable
@@ -31,12 +32,18 @@ def _test_simple(env):
3132
calls.append(url)
3233
if "foo" in url and "main" in url:
3334
return struct(
34-
output = "",
35+
output = struct(
36+
sdists = {"": ""},
37+
whls = {},
38+
),
3539
success = False,
3640
)
3741
else:
3842
return struct(
39-
output = "data from {}".format(url),
43+
output = struct(
44+
sdists = {"": "data from {}".format(url)},
45+
whls = {},
46+
),
4047
success = True,
4148
)
4249

@@ -48,12 +55,14 @@ def _test_simple(env):
4855
index_url_overrides = {},
4956
index_url = "main",
5057
extra_index_urls = ["extra"],
58+
index_strategy = "first-index",
5159
sources = ["foo", "bar", "baz"],
5260
envsubst = [],
5361
),
5462
cache = {},
5563
parallel_download = True,
5664
read_simpleapi = read_simpleapi,
65+
_print = warnings_suggestion.append,
5766
)
5867

5968
env.expect.that_collection(calls).contains_exactly([
@@ -63,13 +72,195 @@ def _test_simple(env):
6372
"main/foo/",
6473
])
6574
env.expect.that_dict(contents).contains_exactly({
66-
"bar": "data from main/bar/",
67-
"baz": "data from main/baz/",
68-
"foo": "data from extra/foo/",
75+
"bar": struct(
76+
sdists = {"": "data from main/bar/"},
77+
whls = {},
78+
),
79+
"baz": struct(
80+
sdists = {"": "data from main/baz/"},
81+
whls = {},
82+
),
83+
"foo": struct(
84+
sdists = {"": "data from extra/foo/"},
85+
whls = {},
86+
),
6987
})
88+
env.expect.that_collection(warnings_suggestion).contains_exactly([
89+
"""\
90+
You can use the following `index_url_overrides` to avoid the 404 warnings:
91+
{
92+
"foo": "extra",
93+
"bar": "main",
94+
"baz": "main",
95+
}""",
96+
])
7097

7198
_tests.append(_test_simple)
7299

100+
def _test_overrides_and_precedence(env):
101+
calls = []
102+
103+
def read_simpleapi(ctx, url, attr, cache, get_auth, block):
104+
_ = ctx # buildifier: disable=unused-variable
105+
_ = attr
106+
_ = cache
107+
_ = get_auth
108+
env.expect.that_bool(block).equals(False)
109+
calls.append(url)
110+
if "foo" in url and "main" in url:
111+
return struct(
112+
output = struct(
113+
sdists = {"": ""},
114+
whls = {},
115+
),
116+
# This will ensure that we fail the test if we go into this
117+
# branch unexpectedly.
118+
success = False,
119+
)
120+
else:
121+
return struct(
122+
output = struct(
123+
sdists = {"": "data from {}".format(url)},
124+
whls = {
125+
url: "whl from {}".format(url),
126+
} if "foo" in url else {},
127+
),
128+
success = True,
129+
)
130+
131+
contents = simpleapi_download(
132+
ctx = struct(
133+
os = struct(environ = {}),
134+
),
135+
attr = struct(
136+
index_url_overrides = {
137+
"foo": "extra1,extra2",
138+
},
139+
index_url = "main",
140+
extra_index_urls = [],
141+
# If we pass overrides, then we will get packages from all indexes.
142+
# However, for packages without index_url_overrides, we will honour
143+
# the strategy setting.
144+
index_strategy = "first-index",
145+
sources = ["foo", "bar", "baz"],
146+
envsubst = [],
147+
),
148+
cache = {},
149+
parallel_download = True,
150+
read_simpleapi = read_simpleapi,
151+
_print = fail,
152+
)
153+
154+
env.expect.that_collection(calls).contains_exactly([
155+
"extra1/foo/",
156+
"extra2/foo/",
157+
"main/bar/",
158+
"main/baz/",
159+
])
160+
env.expect.that_dict(contents).contains_exactly({
161+
"bar": struct(
162+
sdists = {"": "data from main/bar/"},
163+
whls = {},
164+
),
165+
"baz": struct(
166+
sdists = {"": "data from main/baz/"},
167+
whls = {},
168+
),
169+
"foo": struct(
170+
# We prioritize the first index
171+
sdists = {"": "data from extra1/foo/"},
172+
whls = {
173+
"extra1/foo/": "whl from extra1/foo/",
174+
"extra2/foo/": "whl from extra2/foo/",
175+
},
176+
),
177+
})
178+
179+
_tests.append(_test_overrides_and_precedence)
180+
181+
def _test_unsafe_strategy(env):
182+
calls = []
183+
warnings_suggestion = []
184+
185+
def read_simpleapi(ctx, url, attr, cache, get_auth, block):
186+
_ = ctx # buildifier: disable=unused-variable
187+
_ = attr
188+
_ = cache
189+
_ = get_auth
190+
env.expect.that_bool(block).equals(False)
191+
calls.append(url)
192+
return struct(
193+
output = struct(
194+
sdists = {"": "data from {}".format(url)},
195+
whls = {
196+
url: "whl from {}".format(url),
197+
} if "foo" in url else {},
198+
),
199+
success = True,
200+
)
201+
202+
contents = simpleapi_download(
203+
ctx = struct(
204+
os = struct(environ = {}),
205+
),
206+
attr = struct(
207+
index_url_overrides = {
208+
"foo": "extra1,extra2",
209+
},
210+
index_url = "main",
211+
# This field would be ignored for others
212+
extra_index_urls = ["extra"],
213+
# If we pass overrides, then we will get packages from all indexes.
214+
# However, for packages without index_url_overrides, we will honour
215+
# the strategy setting.
216+
index_strategy = "unsafe",
217+
sources = ["foo", "bar", "baz"],
218+
envsubst = [],
219+
),
220+
cache = {},
221+
parallel_download = True,
222+
read_simpleapi = read_simpleapi,
223+
_print = warnings_suggestion.append,
224+
)
225+
226+
env.expect.that_collection(calls).contains_exactly([
227+
"extra1/foo/",
228+
"extra2/foo/",
229+
"main/bar/",
230+
"main/baz/",
231+
"extra/bar/",
232+
"extra/baz/",
233+
])
234+
env.expect.that_dict(contents).contains_exactly({
235+
"bar": struct(
236+
sdists = {"": "data from main/bar/"},
237+
whls = {},
238+
),
239+
"baz": struct(
240+
sdists = {"": "data from main/baz/"},
241+
whls = {},
242+
),
243+
"foo": struct(
244+
# We prioritize the first index
245+
sdists = {"": "data from extra1/foo/"},
246+
whls = {
247+
"extra1/foo/": "whl from extra1/foo/",
248+
"extra2/foo/": "whl from extra2/foo/",
249+
},
250+
),
251+
})
252+
env.expect.that_collection(warnings_suggestion).contains_exactly([
253+
"""\
254+
You can use the following `index_url_overrides` to avoid the 404 warnings:
255+
{
256+
"foo": "extra1,extra2",
257+
"bar": "main,extra",
258+
"baz": "main,extra",
259+
}""",
260+
])
261+
262+
_tests.append(_test_unsafe_strategy)
263+
73264
def _test_fail(env):
74265
calls = []
75266
fails = []
@@ -83,12 +274,18 @@ def _test_fail(env):
83274
calls.append(url)
84275
if "foo" in url:
85276
return struct(
86-
output = "",
277+
output = struct(
278+
sdists = {"": ""},
279+
whls = {},
280+
),
87281
success = False,
88282
)
89283
else:
90284
return struct(
91-
output = "data from {}".format(url),
285+
output = struct(
286+
sdists = {"": "data from {}".format(url)},
287+
whls = {},
288+
),
92289
success = True,
93290
)
94291

@@ -100,17 +297,19 @@ def _test_fail(env):
100297
index_url_overrides = {},
101298
index_url = "main",
102299
extra_index_urls = ["extra"],
300+
index_strategy = "first-index",
103301
sources = ["foo", "bar", "baz"],
104302
envsubst = [],
105303
),
106304
cache = {},
107305
parallel_download = True,
108306
read_simpleapi = read_simpleapi,
109307
_fail = fails.append,
308+
_print = fail,
110309
)
111310

112311
env.expect.that_collection(fails).contains_exactly([
113-
"""Failed to download metadata for ["foo"] for from urls: ["main", "extra"]""",
312+
"""Failed to download metadata for ["foo"] from urls: ["main", "extra"]""",
114313
])
115314
env.expect.that_collection(calls).contains_exactly([
116315
"extra/foo/",
@@ -140,12 +339,14 @@ def _test_download_url(env):
140339
index_url_overrides = {},
141340
index_url = "https://example.com/main/simple/",
142341
extra_index_urls = [],
342+
index_strategy = "first-index",
143343
sources = ["foo", "bar", "baz"],
144344
envsubst = [],
145345
),
146346
cache = {},
147347
parallel_download = False,
148348
get_auth = lambda ctx, urls, ctx_attr: struct(),
349+
_print = fail,
149350
)
150351

151352
env.expect.that_dict(downloads).contains_exactly({
@@ -175,12 +376,14 @@ def _test_download_url_parallel(env):
175376
index_url_overrides = {},
176377
index_url = "https://example.com/main/simple/",
177378
extra_index_urls = [],
379+
index_strategy = "first-index",
178380
sources = ["foo", "bar", "baz"],
179381
envsubst = [],
180382
),
181383
cache = {},
182384
parallel_download = True,
183385
get_auth = lambda ctx, urls, ctx_attr: struct(),
386+
_print = fail,
184387
)
185388

186389
env.expect.that_dict(downloads).contains_exactly({
@@ -210,12 +413,14 @@ def _test_download_envsubst_url(env):
210413
index_url_overrides = {},
211414
index_url = "$INDEX_URL",
212415
extra_index_urls = [],
416+
index_strategy = "first-index",
213417
sources = ["foo", "bar", "baz"],
214418
envsubst = ["INDEX_URL"],
215419
),
216420
cache = {},
217421
parallel_download = False,
218422
get_auth = lambda ctx, urls, ctx_attr: struct(),
423+
_print = fail,
219424
)
220425

221426
env.expect.that_dict(downloads).contains_exactly({

0 commit comments

Comments
 (0)
Please sign in to comment.