Commit 75b6210

Kaiyu/update main (NVIDIA#5)
* Update
* Update
1 parent 4941ad2 commit 75b6210


421 files changed, +1905283 -1537434 lines changed

README.md (+209 -231)

Large diffs are not rendered by default.

benchmarks/cpp/README.md (+5)

@@ -9,6 +9,11 @@ multiple GPUs or multiple nodes with multiple GPUs.
 
 Please follow the [`installation document`](../../../README.md) to build TensorRT-LLM.
 
+Windows users: Follow the
+[`Windows installation document`](../../../windows/README.md)
+instead, and be sure to set DLL paths as specified in
+[Extra Steps for C++ Runtime Usage](../../../windows/README.md#extra-steps-for-c-runtime-usage).
+
 After that, you can build benchmarking source code for C++ runtime
 ```
 cd cpp/build

benchmarks/cpp/gptManagerBenchmark.cpp (+17 -9)

@@ -275,7 +275,8 @@ class GptServer
     GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
         batch_scheduler::SchedulerPolicy schedulerPolicy, std::optional<int32_t> maxNumSequences,
         std::optional<int32_t> maxTokensInPagedKvCache, std::optional<float> kvCacheFreeGpuMemFraction,
-        std::optional<bool> enableTrtOverlap, std::shared_ptr<Recorder> recorder)
+        std::optional<bool> enableTrtOverlap, std::shared_ptr<Recorder> recorder,
+        std::optional<uint64_t> terminateReqId)
     {
         const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams(
             maxNumSequences, maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, enableTrtOverlap);
@@ -285,8 +286,9 @@ class GptServer
             [this](uint64_t requestId, std::list<NamedTensor> response_tensors, bool final_response,
                 const std::string& errMsg)
             { return sendResponse(requestId, response_tensors, final_response, errMsg); },
-            nullptr, nullptr, optionalParams);
+            nullptr, nullptr, optionalParams, terminateReqId);
         mRecorder = recorder;
+        mTerminateReqId = terminateReqId;
     }
 
     ~GptServer()
@@ -298,7 +300,7 @@ class GptServer
     {
         // Create InferenceRequest from a set of tensors
         auto request = std::make_shared<InferenceRequest>(requestId);
-        if (requestId == -1)
+        if (requestId == mTerminateReqId)
         {
             mWorkItemsQueue.push(request, requestId);
             return;
@@ -430,6 +432,7 @@ class GptServer
     std::shared_ptr<GptManager> mBatchManager;
     std::shared_ptr<Recorder> mRecorder;
     WorkItemsQueue mWorkItemsQueue;
+    std::optional<uint64_t> mTerminateReqId;
 
 }; // class GptServer
 
@@ -479,11 +482,7 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
         TLLM_LOG_ERROR(errStr);
     }
 
-    const int maxBeamWidth = 1;
-    auto recorder = std::make_shared<Recorder>();
-    auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, maxNumSequences,
-        maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, enableTrtOverlap, recorder);
-
+    // Load dataset
     auto dataset = parseDataset(datasetPath);
     std::vector<std::vector<NamedTensor>> tensors_list;
     const auto num_samples = dataset.first.size();
@@ -499,6 +498,12 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
         tensors_list.push_back(tensors);
     }
 
+    const int maxBeamWidth = 1;
+    auto recorder = std::make_shared<Recorder>();
+    uint64_t terminateReqId = num_samples + 1;
+    auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, maxNumSequences,
+        maxTokensInPagedKvCache, kvCacheFreeGpuMemFraction, enableTrtOverlap, recorder, terminateReqId);
+
     if (worldConfig.getRank() == 0)
     {
         recorder->initialize();
@@ -510,8 +515,11 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
         recorder->finalize();
         recorder->calculateMetrics();
         recorder->report();
-        gptServer->enqueue({}, -1, false);
+        // Send terminateReqId to terminate servers on all ranks
+        // Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
+        gptServer->enqueue({}, terminateReqId, false);
     }
+    // Wait until benchmarking is done and batch manager is terminated
     gptServer->waitBatchManager();
 }
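
The terminateReqId plumbing above is a sentinel-shutdown handshake: the benchmark's real samples use at most num_samples distinct request ids, so num_samples + 1 appears to be chosen as an id that cannot collide with them, and enqueueing a request with that id is what tells each rank's server loop to stop. A minimal sketch of the same pattern, written in Python with invented names purely for illustration (this is not the GptManager code):

import queue
import threading

def serve(work_items, terminate_req_id):
    # Consume request ids until the sentinel arrives, then shut down.
    while True:
        request_id = work_items.get()
        if request_id == terminate_req_id:
            break  # sentinel request: no work to do, just stop the serving loop
        # ... schedule and execute the real request here ...

num_samples = 4
terminate_req_id = num_samples + 1  # cannot collide with the per-sample ids

work_items = queue.Queue()
server = threading.Thread(target=serve, args=(work_items, terminate_req_id))
server.start()

for request_id in range(1, num_samples + 1):
    work_items.put(request_id)    # real benchmark requests
work_items.put(terminate_req_id)  # analogous to gptServer->enqueue({}, terminateReqId, false)
server.join()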

benchmarks/cpp/gptSessionBenchmark.cpp (+14 -3)

@@ -36,7 +36,8 @@ namespace
 {
 void benchmarkGptSession(std::string const& modelName, std::filesystem::path const& dataPath,
     std::vector<int> const& batchSizes, std::vector<std::vector<int>> const& inOutLen,
-    std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration, bool cudaGraphMode)
+    std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
+    std::optional<SizeType> numMicroBatches, bool cudaGraphMode)
 {
     auto const json = GptJsonConfig::parse(dataPath / "config.json");
     auto const modelConfig = json.getModelConfig();
@@ -73,7 +74,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
 
     for (auto const batchSize : batchSizes)
     {
-        session.setup(batchSize, beamWidth, maxInputLength + maxNewTokens, decoderPerRequest);
+        session.setup(
+            batchSize, beamWidth, maxInputLength + maxNewTokens, decoderPerRequest, std::nullopt, numMicroBatches);
 
         std::vector<SizeType> inputLenghtsHost(batchSize, maxInputLength);
         auto inputLenghts
@@ -163,6 +165,8 @@ int main(int argc, char* argv[])
         cxxopts::value<int>()->default_value("10"));
     options.add_options()("duration", "Minimal duration of iterations to measure in seconds.",
         cxxopts::value<int>()->default_value("60"));
+    options.add_options()(
+        "num_micro_batches", "Number of micro batches if enabling pipeline parallelism.", cxxopts::value<int>());
 
     options.add_options()("enable_cuda_graph", "Execute GPT session with CUDA graph.");
 
@@ -235,6 +239,13 @@ int main(int argc, char* argv[])
         return 1;
     }
 
+    // Argument: Number of micro batches
+    std::optional<SizeType> numMicroBatches{std::nullopt};
+    if (result.count("num_micro_batches"))
+    {
+        numMicroBatches = result["num_micro_batches"].as<int>();
+    }
+
     // Argument: Enable CUDA graph
     auto enableCudaGraph = result.count("enable_cuda_graph") > 0;
 
@@ -244,7 +255,7 @@ int main(int argc, char* argv[])
     {
        benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
            inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(),
-            enableCudaGraph);
+            numMicroBatches, enableCudaGraph);
     }
     catch (const std::exception& e)
     {
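
The new --num_micro_batches option matters when the engine is built with pipeline parallelism: splitting each batch into several micro batches lets pipeline stages on different GPUs work on different micro batches concurrently instead of idling. As a rough illustration of the splitting step only (a hypothetical helper in Python, not TensorRT-LLM code), assuming a batch is just a list of requests:

def split_into_micro_batches(batch, num_micro_batches):
    # Split a batch into num_micro_batches roughly equal, contiguous chunks.
    chunk = (len(batch) + num_micro_batches - 1) // num_micro_batches
    return [batch[i:i + chunk] for i in range(0, len(batch), chunk)]

print(split_into_micro_batches(list(range(8)), num_micro_batches=4))
# [[0, 1], [2, 3], [4, 5], [6, 7]]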

benchmarks/python/all_reduce.py (+137)

@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argparse import ArgumentParser
+
+import tensorrt as trt
+import torch
+from cuda import cuda, cudart
+from mpi4py import MPI
+from polygraphy.backend.trt import CreateConfig, EngineFromNetwork
+
+import tensorrt_llm as tllm
+from tensorrt_llm import Mapping, Tensor
+from tensorrt_llm._ipc_utils import IpcMemory, peer_access
+from tensorrt_llm.functional import AllReduceStrategy, allreduce
+
+
+def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
+    tllm.logger.set_level('error')
+    world_size = tllm.mpi_world_size()
+    rank = tllm.mpi_rank()
+
+    torch.cuda.set_device(rank)
+    cudart.cudaSetDevice(rank)
+
+    mapping = Mapping(world_size, rank, world_size, world_size)
+
+    if world_size == 1:
+        raise RuntimeError("Benchmark must run with mpi_world_size > 1")
+
+    ipc_barriers_in = IpcMemory(
+        mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size)
+    ipc_barriers_out = IpcMemory(
+        mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size)
+    torch_dtype = tllm._utils.str_dtype_to_torch(dtype)
+
+    min_size, max_size, ratio = [int(i) for i in test_range.split(",")]
+    inner_loop = 1000
+
+    size = min_size
+    while size < max_size:
+        ipc_buffers = IpcMemory(mapping, size * 4)
+        workspace = torch.tensor(ipc_buffers.serialize() +
+                                 ipc_barriers_in.serialize() +
+                                 ipc_barriers_out.serialize(),
+                                 dtype=torch.int64,
+                                 device="cpu")
+
+        input = torch.zeros(size, dtype=torch_dtype, device="cuda")
+
+        for strategy in [
+                AllReduceStrategy.RING, AllReduceStrategy.ONESHOT,
+                AllReduceStrategy.TWOSHOT
+        ]:
+            builder = tllm.Builder()
+            net = builder.create_network()
+            net.plugin_config.set_nccl_plugin(dtype)
+
+            with tllm.net_guard(net):
+                network = tllm.default_trtnet()
+
+                x = Tensor(name='x',
+                           shape=input.shape,
+                           dtype=tllm.str_dtype_to_trt(dtype))
+
+                w = Tensor(name='workspace',
+                           shape=workspace.shape,
+                           dtype=trt.int64)
+
+                current = x
+                for i in range(inner_loop):
+                    current = allreduce(
+                        current, mapping.tp_group,
+                        w if strategy != AllReduceStrategy.RING else None, i,
+                        strategy)
+                output = current.trt_tensor
+
+                output.name = 'output'
+                output.dtype = tllm.str_dtype_to_trt(dtype)
+                network.mark_output(output)
+
+            build_engine = EngineFromNetwork(
+                (builder.trt_builder, net.trt_network),
+                config=CreateConfig(
+                    fp16=(dtype == 'float16'),
+                    bf16=(dtype == 'bfloat16'),
+                    precision_constraints='obey',
+                ))
+
+            output = torch.zeros_like(input)
+
+            stream = torch.cuda.current_stream()
+            feed_dict = {'x': input, 'workspace': workspace}
+
+            session = tllm.runtime.Session.from_engine(build_engine())
+            _, start = cuda.cuEventCreate(0)
+            _, stop = cuda.cuEventCreate(0)
+            with peer_access(mapping):
+                MPI.COMM_WORLD.barrier()
+
+                cuda.cuEventRecord(start, stream.cuda_stream)
+                session.run(inputs=feed_dict,
+                            outputs={"output": output},
+                            stream=stream.cuda_stream)
+                cuda.cuEventRecord(stop, stream.cuda_stream)
+                torch.cuda.synchronize()
+                _, ms = cuda.cuEventElapsedTime(start, stop)
+
+            if mapping.rank == 0:
+                print(f"{size=}, {strategy=}, {ms=}")
+        size *= ratio
+        if mapping.rank == 0:
+            print("")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--dtype", "-t", default="float16")
+    parser.add_argument("--range",
+                        "-r",
+                        default="256,25600000,10",
+                        help="min_size,max_size,multiplicative_ratio")
+    args = parser.parse_args()
+
+    allreduce_benchmark(args.dtype, args.range)
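
The --range argument ("min_size,max_size,multiplicative_ratio") drives the sweep in allreduce_benchmark: the message size starts at min_size elements and is multiplied by ratio until it reaches max_size. A standalone sketch of the same loop, showing the sizes the default 256,25600000,10 visits:

min_size, max_size, ratio = (int(i) for i in "256,25600000,10".split(","))
sizes = []
size = min_size
while size < max_size:
    sizes.append(size)
    size *= ratio
print(sizes)  # [256, 2560, 25600, 256000, 2560000]

Since the script raises an error unless tllm.mpi_world_size() > 1, it is evidently meant to be started through an MPI launcher, e.g. something along the lines of mpirun -n 2 python benchmarks/python/all_reduce.py --dtype float16; the exact launch command is an assumption, not part of this commit.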

benchmarks/python/benchmark.py (+10 -3)

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import multiprocessing as mp
 from multiprocessing import Process, Queue
 from time import time
 
@@ -38,11 +39,12 @@ def parse_arguments():
         '--mode',
         type=str,
         default="plugin",
-        choices=['ootb', 'plugin'],
+        choices=['ootb', 'plugin', 'ootb-except-mha'],
         help=
         ('Choose mode between ootb/plugin. '
          '\"ootb\" means the engines will be built without any plugins, '
-         'while \"plugin\" means the engines will be built with tuned recipe of using plugins.'
+         '\"plugin\" means the engines will be built with tuned recipe of using plugins.'
+         '\"ootb-except-mha\" means the engines will be built with only attention plugins.'
         ))
 
     parser.add_argument('--batch_size',
@@ -298,12 +300,16 @@ def main(args):
             )
 
     except Exception as e:
+        print("Found exception during benchmarking", e.with_traceback())
         p.kill()
         raise e
-
+    logger.debug("Sending signal to mem monitor process, start")
     q1.put(1)
+    logger.debug("Sending signal to mem monitor process, done")
     peak_gpu_used = q2.get()
+    logger.debug("Get peak gpu memory usage from mem monitor process, done")
     p.join()
+    logger.debug("Memory monitor process joined")
 
     latency = round(sum(latencies) / iter_idx, 3)
     latencies.sort()
@@ -318,5 +324,6 @@ def main(args):
 
 
 if __name__ == '__main__':
+    mp.set_start_method('spawn')
     args = parse_arguments()
     main(args)
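
The mp.set_start_method('spawn') call pairs with the Process/Queue memory monitor used in main(): once the parent process has initialized CUDA, a fork-started child can inherit an unusable CUDA context, whereas spawn starts the child in a fresh interpreter. A stripped-down sketch of that monitor pattern, with illustrative names only (not the repository's monitor code):

import multiprocessing as mp
import time

def monitor(q_stop, q_result):
    peak = 0
    while q_stop.empty():    # run until the parent signals us to stop
        peak = max(peak, 0)  # a real monitor would poll GPU memory here
        time.sleep(0.1)
    q_result.put(peak)       # hand the measurement back to the parent

if __name__ == '__main__':
    mp.set_start_method('spawn')  # child starts clean, no inherited CUDA state
    q_stop, q_result = mp.Queue(), mp.Queue()
    p = mp.Process(target=monitor, args=(q_stop, q_result))
    p.start()
    # ... run the benchmarked workload in the parent process ...
    q_stop.put(1)            # mirrors q1.put(1) above
    peak = q_result.get()    # mirrors peak_gpu_used = q2.get()
    p.join()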

benchmarks/python/gpt_benchmark.py (+14 -14)

@@ -57,7 +57,7 @@ def __init__(self,
         self.refit = refit
         self.num_beams = num_beams
         self.build_time = 0
-        self.mode = mode  # plugin or ootb
+        self.mode = mode  # plugin or ootb or ootb-except-mha
         self.fuse_bias = True
 
         self.cuda_graph_mode = kwargs.get('enable_cuda_graph', False)
@@ -83,17 +83,20 @@ def __init__(self,
         self.per_token = False
         self.per_channel = False
 
-        is_plugin_mode = mode == 'plugin'
-        plg_dtype = dtype if is_plugin_mode else False
-        self.use_gpt_attention_plugin = plg_dtype
-        self.use_gemm_plugin = plg_dtype
+        use_mha_plugin = mode == 'plugin' or mode == 'ootb-except-mha'
+        mha_plg_dtype = dtype if use_mha_plugin else False
+        use_non_mha_plugin = mode == 'plugin'
+        non_mha_plg_dtype = dtype if use_mha_plugin else False
+
+        self.use_gpt_attention_plugin = mha_plg_dtype
+        self.use_gemm_plugin = non_mha_plg_dtype
         # Starting TRT9.1 OOTB norm layer sees improvement over plugin norm layer
         self.use_layernorm_plugin = False
         self.use_rmsnorm_plugin = False
-        self.use_lookup_plugin = plg_dtype
-        self.enable_context_fmha = True
+        self.use_lookup_plugin = non_mha_plg_dtype
+        self.enable_context_fmha = use_mha_plugin
         self.quant_mode = QuantMode(0)
-        self.remove_input_padding = is_plugin_mode
+        self.remove_input_padding = use_non_mha_plugin
 
         for key, value in get_build_config(model_name).items():
             setattr(self, key, value)
@@ -135,8 +138,6 @@ def __init__(self,
             self.quant_mode = self.quant_mode.set_fp8_qdq()
 
         if self.fp8_kv_cache:
-            # Watch out, enable_fp8 and fp8_kv_cache are not exclusive
-            assert self.use_gpt_attention_plugin, "GPT attention plugin needed"
             self.quant_mode = self.quant_mode.set_fp8_kv_cache()
 
         engine_buffer = self.build()
@@ -151,7 +152,9 @@ def __init__(self,
             num_layers=self.num_layers,
             gpt_attention_plugin=self.use_gpt_attention_plugin,
             remove_input_padding=self.remove_input_padding,
-            quant_mode=self.quant_mode)
+            quant_mode=self.quant_mode,
+            use_custom_all_reduce=self.enable_custom_all_reduce,
+        )
         if model_name == 'chatglm_6b':
             self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
                 end_id=130005,
@@ -392,9 +395,6 @@ def build(self):
             network.plugin_config.set_smooth_quant_gemm_plugin(dtype=self.dtype)
             network.plugin_config.set_layernorm_quantization_plugin(
                 dtype=self.dtype)
-            # FIXME(nkorobov)
-            # See https://nvbugs/4164762
-            # See https://nvbugs/4174113
             network.plugin_config.set_quantize_tensor_plugin()
             network.plugin_config.set_quantize_per_token_plugin()
         elif self.use_weight_only:
