diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml
index 8db79fc3..03e4300e 100644
--- a/.github/workflows/unit_tests.yaml
+++ b/.github/workflows/unit_tests.yaml
@@ -65,7 +65,9 @@ jobs:
       with:
         python-version: ${{ matrix.python-version }}
     - name: Install Dependencies
-      run: make install-deps
+      run: |
+        make install-deps
+        make install-submodules
     - name: Run all unit tests in JetStream (jetstream/tests)
       run: make unit-tests
     - name: Create test coverage report
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..b4480d50
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "jetstream/engine/implementations/maxtext"]
+	path = jetstream/engine/implementations/maxtext
+	url = https://github.com/google/maxtext.git
diff --git a/Makefile b/Makefile
index a8a88085..fb4f101f 100644
--- a/Makefile
+++ b/Makefile
@@ -5,9 +5,15 @@ GRPC_TOOLS_VERSION := 1.62.1
 all: install-deps generate-protos format check
 
 # Dependency management targets
+
 install-deps:
 	$(PIP) install pytype pylint pyink -r requirements.txt -r benchmarks/requirements.in
 
+install-submodules:
+	git submodule update --init --recursive
+	- $(PIP) install -r ./jetstream/engine/implementations/maxtext/requirements.txt
+	- $(PIP) install jetstream_pt@git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt
+
 # Code generation/formatting targets
 generate-protos: generate-and-prepend-preambles format
 
diff --git a/README.md b/README.md
index 62959c46..7fa1d65a 100644
--- a/README.md
+++ b/README.md
@@ -47,7 +47,7 @@ make install-deps
 Use the following commands to run a server locally:
 ```
 # Start a server
-python -m jetstream.core.implementations.mock.server
+python -m jetstream.entrypoints.mock.server
 
 # Test local mock server
 python -m jetstream.tools.requester
diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py
index ee979964..4101fcd1 100644
--- a/jetstream/engine/__init__.py
+++ b/jetstream/engine/__init__.py
@@ -21,3 +21,11 @@
 except ImportError as e:
   print("Proxy backend support is not added")
   pass
+
+import os
+import sys
+
+submodule_path = os.path.join(
+    os.path.dirname(__file__), "implementations/maxtext/MaxText"
+)
+sys.path.append(submodule_path)
diff --git a/jetstream/engine/implementations/maxtext b/jetstream/engine/implementations/maxtext
new file mode 160000
index 00000000..2a6154f2
--- /dev/null
+++ b/jetstream/engine/implementations/maxtext
@@ -0,0 +1 @@
+Subproject commit 2a6154f254bf5dbe67e659360775a83a797ed7f9
diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py
index 0277e9a3..10c06e6d 100644
--- a/jetstream/engine/mock_engine.py
+++ b/jetstream/engine/mock_engine.py
@@ -129,7 +129,7 @@ def prefill(
         samples_per_slot=self.generate_cache_batch // self.prefill_cache_batch,
     )
 
-    return (prefill_cache, first_step), first_token
+    return (prefill_cache.astype(jnp.float32), first_step), first_token
 
   @functools.partial(jax.jit, static_argnums=(0,))
   def generate(
@@ -152,7 +152,7 @@ def generate(
 
     # Update generate cache
     generate_cache = jax.lax.dynamic_update_slice_in_dim(
-        generate_cache,
+        generate_cache.astype(jnp.float32),
         previous_timestep,
         start_index=generate_cache_index,
         axis=1,
@@ -198,7 +198,7 @@ def generate(
     )
     return DecodeState(
         prefill_cache=prefill_cache,
-        generate_cache=generate_cache,
+        generate_cache=generate_cache.astype(jnp.float32),
         generate_cache_index=generate_cache_index,
         generate_lengths=new_lengths,
         generate_tokens=new_timestep,
@@ -230,7 +230,7 @@ def insert(
     )
     generate_cache = jax.lax.dynamic_update_slice_in_dim(
         decode_state.generate_cache,
-        jnp.zeros((1, self.cache_length)),
+        jnp.zeros((1, self.cache_length), dtype=jnp.float32),
         slot,
         axis=0,
     )
@@ -243,7 +243,7 @@ def insert(
     )
     generate_tokens = jax.lax.dynamic_update_slice_in_dim(
         decode_state.generate_tokens,
-        previous_timestep,
+        previous_timestep.astype(jnp.float32),
         slot * samples_per_slot,
         axis=0,
     )
diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py
index 79f2b012..644cc38f 100644
--- a/jetstream/entrypoints/config.py
+++ b/jetstream/entrypoints/config.py
@@ -14,15 +14,48 @@
 
 """Config for JetStream Server (including engine init)."""
 
-from typing import Type
+import os
+from typing import Sequence, Type
 
+import jax
 from jetstream.core import config_lib
+from jetstream.engine.implementations.maxtext.MaxText import maxengine, pyconfig
+from jetstream_pt import config
 
 
 def get_server_config(
-    config_str: str,
+    config_str: str, argv: Sequence[str]
 ) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]:
   match config_str:
+    case "MaxtextInterleavedServer":
+      jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+      os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+      pyconfig.initialize(argv)
+      server_config = config_lib.ServerConfig(
+          prefill_slices=(),
+          generate_slices=(),
+          interleaved_slices=("tpu=" + str(jax.device_count()),),
+          prefill_engine_create_fns=(),
+          generate_engine_create_fns=(),
+          interleaved_engine_create_fns=(
+              # Factory takes one positional argument (unused, as in the
+              # PyTorch branch); the engine is built from the parsed pyconfig.
+              lambda a: maxengine.MaxEngine(pyconfig.config),
+          ),
+      )
+    case "PyTorchInterleavedServer":
+      os.environ["XLA_FLAGS"] = (
+          "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"
+      )
+      engine = config.create_engine_from_config_flags()
+      server_config = config_lib.ServerConfig(
+          prefill_slices=(),
+          generate_slices=(),
+          interleaved_slices=("tpu=" + str(jax.device_count()),),
+          prefill_engine_create_fns=(),
+          generate_engine_create_fns=(),
+          interleaved_engine_create_fns=(lambda a: engine,),
+      )
     case "InterleavedCPUTestServer":
       server_config = config_lib.InterleavedCPUTestServer
     case "CPUTestServer":
diff --git a/jetstream/core/implementations/__init__.py b/jetstream/entrypoints/grpc/__init__.py
similarity index 100%
rename from jetstream/core/implementations/__init__.py
rename to jetstream/entrypoints/grpc/__init__.py
diff --git a/jetstream/entrypoints/grpc/server.py b/jetstream/entrypoints/grpc/server.py
new file mode 100644
index 00000000..e54fb767
--- /dev/null
+++ b/jetstream/entrypoints/grpc/server.py
@@ -0,0 +1,65 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs a JetStream Server."""
+
+from typing import Sequence
+
+from absl import app
+from absl import flags
+
+from jetstream.entrypoints import config
+from jetstream.core import config_lib, server_lib
+
+
+flags.DEFINE_integer("port", 9000, "port to listen on")
+flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
+flags.DEFINE_string(
+    "config",
+    "InterleavedCPUTestServer",
+    "available servers",
+)
+flags.DEFINE_integer("prometheus_port", 0, "")
+
+
+def main(argv: Sequence[str]):
+  # Build the server config before querying devices, mirroring
+  # entrypoints/http/api_server.py: config selection sets process-level
+  # options (e.g. XLA_FLAGS) that should be in place first.
+  server_config = config.get_server_config(flags.FLAGS.config, argv)
+  print(f"server_config: {server_config}")
+  del argv
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+
+  metrics_server_config: config_lib.MetricsServerConfig | None = None
+  if flags.FLAGS.prometheus_port != 0:
+    metrics_server_config = config_lib.MetricsServerConfig(
+        port=flags.FLAGS.prometheus_port
+    )
+  # We separate credential from run so that we can unit test it with local
+  # credentials.
+  # TODO: Add grpc credentials for OSS.
+  jetstream_server = server_lib.run(
+      threads=flags.FLAGS.threads,
+      port=flags.FLAGS.port,
+      config=server_config,
+      devices=devices,
+      metrics_server_config=metrics_server_config,
+  )
+  jetstream_server.wait_for_termination()
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py
index aaced235..3879b435 100644
--- a/jetstream/entrypoints/http/api_server.py
+++ b/jetstream/entrypoints/http/api_server.py
@@ -94,11 +94,11 @@ def server(argv: Sequence[str]):
   app.include_router(router)
 
   # Init LLMOrchestrator which would be the main handler in the api endpoints.
-  devices = server_lib.get_devices()
-  print(f"devices: {devices}")
-  server_config = get_server_config(flags.FLAGS.config)
+  server_config = get_server_config(flags.FLAGS.config, argv)
   print(f"server_config: {server_config}")
   del argv
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
 
   metrics_server_config: config_lib.MetricsServerConfig | None = None
   # Setup Prometheus server
diff --git a/jetstream/core/implementations/mock/README.md b/jetstream/entrypoints/mock/README.md
similarity index 100%
rename from jetstream/core/implementations/mock/README.md
rename to jetstream/entrypoints/mock/README.md
diff --git a/jetstream/core/implementations/mock/__init__.py b/jetstream/entrypoints/mock/__init__.py
similarity index 100%
rename from jetstream/core/implementations/mock/__init__.py
rename to jetstream/entrypoints/mock/__init__.py
diff --git a/jetstream/core/implementations/mock/config.py b/jetstream/entrypoints/mock/config.py
similarity index 100%
rename from jetstream/core/implementations/mock/config.py
rename to jetstream/entrypoints/mock/config.py
diff --git a/jetstream/core/implementations/mock/server.py b/jetstream/entrypoints/mock/server.py
similarity index 95%
rename from jetstream/core/implementations/mock/server.py
rename to jetstream/entrypoints/mock/server.py
index 6a0cee76..aca0c427 100644
--- a/jetstream/core/implementations/mock/server.py
+++ b/jetstream/entrypoints/mock/server.py
@@ -19,7 +19,7 @@
 from absl import app
 from absl import flags
 
-from jetstream.core.implementations.mock import config as mock_config
+from jetstream.entrypoints.mock import config as mock_config
 from jetstream.core import server_lib
 
 
diff --git a/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py
index e6d42e58..eeebe694 100644
--- a/jetstream/tests/entrypoints/http/test_api_server.py
+++ b/jetstream/tests/entrypoints/http/test_api_server.py
@@ -14,6 +14,7 @@
 
 """Tests http server end-to-end."""
 
+import os
 import subprocess
 import sys
 import time
@@ -29,6 +30,9 @@ class HTTPServerTest(unittest.IsolatedAsyncioTestCase):
   def setUpClass(cls):
     """Sets up a JetStream http server for unit tests."""
     cls.base_url = "http://localhost:8080"
+    my_env = os.environ.copy()  # Create a copy of the current environment
+    my_env["JAX_PLATFORMS"] = "cpu"
+    my_env["JAX_TRACEBACK_FILTERING"] = "off"
     cls.server = subprocess.Popen(
         [
             "python",
@@ -36,6 +40,7 @@ def setUpClass(cls):
             "jetstream.entrypoints.http.api_server",
             "--config=InterleavedCPUTestServer",
         ],
+        env=my_env,
         stdout=sys.stdout,
         stderr=sys.stderr,
     )
diff --git a/requirements-standalone.txt b/requirements-standalone.txt
new file mode 100644
index 00000000..716e7ba2
--- /dev/null
+++ b/requirements-standalone.txt
@@ -0,0 +1,27 @@
+# jetstream library
+absl-py
+coverage
+flax
+grpcio
+jax
+jaxlib
+numpy
+portpicker
+prometheus-client
+pytest
+seqio
+tiktoken
+blobfile
+parameterized
+shortuuid
+# jetstream benchmarks
+nltk
+evaluate
+rouge-score
+tqdm
+# jetstream profiling
+tensorboard-plugin-profile
+# engines
+# maxtext @ git+https://github.com/google/maxtext.git@jetstream-v0.2.2#egg=maxtext
+# maxtext @ {root:uri}/jetstream/engine/implementations/maxtext
+jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt