# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ctypes
import platform
from collections import OrderedDict
from dataclasses import asdict, dataclass, field, fields
from enum import IntEnum
from pathlib import Path
from textwrap import dedent
from typing import List, Optional, Tuple
import tensorrt as trt
from .._ipc_utils import IpcMemory, can_access_peer
from ..bindings.internal.runtime import lamport_initialize_all
from ..logger import logger
from ..mapping import Mapping
TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm'


def plugin_lib_path() -> str:
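    """Return the path to the TensorRT-LLM plugin shared library bundled with
    the package (.so on Linux, .dll on Windows)."""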
project_dir = Path(__file__).parent.parent.absolute()
dyn_lib = "libnvinfer_plugin_tensorrt_llm.so" if platform.system(
) != "Windows" else "nvinfer_plugin_tensorrt_llm.dll"
return str(project_dir.joinpath("libs", dyn_lib))


def _load_plugin_lib():
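    """Load the TensorRT-LLM plugin shared library and register its plugins
    with TensorRT under the `tensorrt_llm` plugin namespace."""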
on_windows = platform.system() == "Windows"
winmode = 0 if on_windows else None
handle = ctypes.CDLL(plugin_lib_path(),
mode=ctypes.RTLD_GLOBAL,
winmode=winmode)
try:
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
handle.initTrtLlmPlugins.restype = ctypes.c_bool
except AttributeError as err:
raise ImportError('TensorRT-LLM Plugin is unavailable') from err
try:
assert handle.initTrtLlmPlugins(
None, TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8'))
except OSError as e:
windows_err = """
The error above may be caused by an outdated Microsoft Visual C++ Redistributable Version.
Please install the latest MSVC from the link below and re-launch.
https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170#latest-microsoft-visual-c-redistributable-version
"""
err_msg = dedent(windows_err if on_windows else "Unknown error")
raise RuntimeError(err_msg) from e
except Exception as e:
raise e


class ContextFMHAType(IntEnum):
disabled = 0
# FP16 I/O, FP16 Accumulation
enabled = 1
# FP16 I/O, FP32 Accumulation
enabled_with_fp32_acc = 2


DEFAULT_PLUGIN_DTYPE_OPTIONS = [
"auto", "float16", "float32", "bfloat16", "int32", None
]
PLUGIN_DTYPE_OPTIONS_MAP = {
"gemm_swiglu_plugin": ["fp8", None],
"gemm_plugin":
["auto", "float16", "float32", "bfloat16", "int32", "fp8", None],
"low_latency_gemm_plugin": ["fp8", None],
"low_latency_gemm_swiglu_plugin": ["fp8", None],
}


def _make_plugin_property(field_name: str, field_type: type):
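    """Create a property that wraps a PluginConfig field.

    The getter resolves the special value "auto" to the top-level `dtype` field;
    the setter validates the assigned value against the allowed options for the
    field before storing it in the underscore-prefixed storage attribute.
    """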
def bind(field_name):
storage_name = f'_{field_name}'
@property
def prop(self):
field_value = getattr(self, storage_name)
if field_name != 'dtype' and field_value == 'auto':
return self.dtype
else:
return field_value
@prop.setter
def prop(self, value):
if field_type is bool:
assert isinstance(value, bool), \
f"Plugin {field_name} expects {field_type}, got {type(value)}"
elif field_type in (str, Optional[str]):
plugin_dtype_options = DEFAULT_PLUGIN_DTYPE_OPTIONS
if field_name in PLUGIN_DTYPE_OPTIONS_MAP:
plugin_dtype_options = PLUGIN_DTYPE_OPTIONS_MAP[field_name]
assert value in plugin_dtype_options, \
f"Plugin {field_name} expects values in {plugin_dtype_options}, got {value}"
if field_name == 'dtype':
assert value not in ['auto', None], \
"Plugin dtype cannot be auto or None"
setattr(self, storage_name, value)
logger.info(f"Set {field_name} to {value}.")
return prop
return bind(field_name)


class PluginConfigMeta(type):
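    """Metaclass that turns every annotated `_field` storage attribute into a
    public, validated `field` property (see _make_plugin_property)."""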
def __new__(cls, name, bases, attrs):
for storage_name, field_type in attrs['__annotations__'].items():
assert storage_name.startswith('_')
field_name = storage_name.lstrip('_')
attrs[field_name] = _make_plugin_property(field_name, field_type)
return super().__new__(cls, name, bases, attrs)


@dataclass(slots=True)
class PluginConfig(metaclass=PluginConfigMeta):
"""The config that manages plugin-related options.
There are two option categories:
    * Plugin options (typically named xxx_plugin). These options can be assigned:
        * "float16"/"bfloat16"/"float32"/"int32", which means the plugin is enabled with the specified precision (some plugins support only a limited set of dtypes, e.g., gemm_swiglu_plugin and low_latency_gemm_swiglu_plugin currently support only fp8);
        * "auto", which means the plugin is enabled with the precision of the `dtype` field (the `dtype` field must match the model dtype, i.e., the one in PretrainedConfig);
        * None, which means the plugin is disabled.
    * Other features. These options take a boolean:
        * True, which means the feature is enabled;
        * False, which means the feature is disabled.
Note: All the fields should use a prefix "_"; PluginConfigMeta will wrap each field as a property.
This ensures the fields can only be assigned with allowed values.
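
    Example (illustrative only; the attributes used below are fields of this class):
        >>> config = PluginConfig()
        >>> config.gemm_plugin = "float16"      # enable a plugin with an explicit precision
        >>> config.moe_plugin = "auto"          # follow the precision of `config.dtype`
        >>> config.lora_plugin = None           # disable a plugin
        >>> config.remove_input_padding = True  # boolean feature flag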
"""
_dtype: str = field(default="float16", init=False)
# Plugins
_bert_attention_plugin: Optional[str] = field(
default="auto",
init=False,
metadata={
"help":
"The plugin that uses efficient kernels and enables an in-place update of the KV cache for attention layer of BERT-like encoder models."
})
_gpt_attention_plugin: Optional[str] = field(
default="auto",
init=False,
metadata={
"help":
"The plugin that uses efficient kernels and enables an in-place update of the KV cache for attention layer of GPT-like decoder models."
})
_gemm_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"The GEMM plugin that utilizes NVIDIA cuBLASLt to perform GEMM operations. "
"Note: it's only affective for non-quantized gemm operations (except FP8)."
"Note: For FP8, it also requires same calibration in checkpoint."
})
_gemm_swiglu_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"The GEMM + SwiGLU fusion in Gated-MLP combines two Matmul operations and "
"one SwiGLU operation into a single kernel. Currently this is only supported for FP8 precision on Hopper."
})
_fp8_rowwise_gemm_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"The quantized GEMM for fp8, which uses per token dynamic scales for "
"activation and per channel static scales for weights."
"Note: It also requires same calibration in checkpoint."
})
_qserve_gemm_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"The quantized GEMM from [QServe](https://arxiv.org/abs/2405.04532), "
"which employs 4-bit quantization for weights and 8-bit quantization for activations."
})
_identity_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"The identity plugin simply copies inputs to outputs, it's used mostly for debugging purpose."
})
_nccl_plugin: Optional[str] = field(
default="auto",
init=False,
metadata={
"help":
"The NCCL plugin wraps NCCL operators to support multi-GPU and even multi-nodes."
})
_lora_plugin: Optional[str] = field(default=None,
init=False,
metadata={"help": "Enable LoRA."})
_weight_only_groupwise_quant_matmul_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"Enable weight-only groupwise quantization matmul operators."
})
_weight_only_quant_matmul_plugin: Optional[str] = field(
default=None,
init=False,
metadata={"help": "Enable weight-only quantization matmul operators."})
_smooth_quant_plugins: bool = field(
default=True,
init=False,
metadata={
"help": "Enable a group of plugins to support smooth quantization."
})
_smooth_quant_gemm_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"Enable plugin that supports smooth quantization gemm kernels."
})
_layernorm_quantization_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"Enable plugin that supports layernorm quantization kernels."
})
_rmsnorm_quantization_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help": "Enable plugin that supports rmsnorm quantization kernels."
})
_quantize_per_token_plugin: bool = field(
default=False,
init=False,
metadata={
"help": "Enable plugin that supports per-token quantization."
})
_quantize_tensor_plugin: bool = field(
default=False,
init=False,
metadata={
"help": "Enable plugin that supports per-tensor quantization."
})
_moe_plugin: Optional[str] = field(
default="auto",
init=False,
metadata={
"help":
"Enable some customized kernels to speed up the MoE layer of MoE models."
})
_mamba_conv1d_plugin: Optional[str] = field(
default="auto",
init=False,
metadata={
"help":
"Enable customized kernels to speed up conv1d operator for Mamba."
})
_low_latency_gemm_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"The GEMM plugin that optimized specially for low latency scenarios."
})
_low_latency_gemm_swiglu_plugin: Optional[str] = field(
default=None,
init=False,
metadata={
"help":
"The GEMM + SwiGLU fusion plugin that optimized specially for low latency scenarios."
})
# Features
_context_fmha: bool = field(
default=True,
init=False,
metadata={
"help":
"Enable the fused multi-head attention during the context phase, "
"will trigger a kernel that performs the MHA/MQA/GQA block using a single kernel."
})
_bert_context_fmha_fp32_acc: bool = field(
default=False,
init=False,
metadata={
"help":
"Enable the FP32 accumulator for context FMHA in the bert_attention_plugin. "
"If disabled, FP16 is used, better performance but potentially worse accuracy is expected."
})
_paged_kv_cache: Optional[bool] = field(
default=None,
init=False,
metadata={
"help":
"Enable paged KV cache, which helps manage memory for the KV cache more efficiently, "
"and usually leads to an increase in the batch size and an improved efficiency."
})
_remove_input_padding: bool = field(
default=True,
init=False,
metadata={
"help":
"Pack different tokens together, which reduces both the amount of computations and memory consumption."
})
_reduce_fusion: bool = field(
default=False,
init=False,
metadata={
"help":
"Fuse the ResidualAdd and LayerNorm kernels after AllReduce into a single kernel, "
"resulting in improved end-to-end performance."
})
_user_buffer: bool = field(
default=False,
init=False,
metadata={
"help":
"Eliminate extra copies from the local buffer to the shared buffer "
"in the communication kernel, leading to improved end-to-end performance. "
"This feature must be enabled with `--reduce_fusion enable` and "
"is currently only supported for the FP8 LLAMA model."
})
_tokens_per_block: int = field(
default=64,
init=False,
metadata={
"help":
"Define how many tokens are contained in each paged kv cache block."
})
_use_paged_context_fmha: bool = field(
default=False,
init=False,
metadata={
"help":
"Allow advanced features like KV cache reuse and chunked context."
})
_use_fp8_context_fmha: bool = field(
default=False,
init=False,
metadata={
"help":
"When FP8 quantization is activated, the attention can be further accelerated by enabling FP8 Context FMHA"
})
_multiple_profiles: bool = field(
default=False,
init=False,
metadata={
"help":
"Enables multiple TensorRT optimization profiles in the built engines, "
"will benefits the performance especially when GEMM plugin is disabled, "
"because more optimization profiles help TensorRT have more chances to select better kernels. "
"Note: This feature increases engine build time but no other adverse effects are expected."
})
_paged_state: bool = field(
default=True,
init=False,
metadata={
"help":
"Enable paged state, which helps manage memory for the RNN state more efficiently."
})
_streamingllm: bool = field(
default=False,
init=False,
metadata={
"help":
"Enable [StreamingLLM](https://arxiv.org/abs/2309.17453), which uses a window attention to perform efficient and stable LLM on long texts."
})
_manage_weights: bool = field(
default=False,
init=False,
metadata={
"help":
"Enable TensorRT-LLM managed weights to speed up engine building process."
})
_use_fused_mlp: bool = field(
default=True,
init=False,
metadata={
"help":
"Enable horizontal fusion in Gated-MLP that combines two Matmul "
"operations into a single one followed by a separate SwiGLU kernel."
})
_pp_reduce_scatter: bool = field(
default=False,
init=False,
metadata={
"help":
"Enable a pipeline parallelism optimization with "
"ReduceScatter + AllGather targeting large MoE models."
})
def update_from_dict(self, config: dict):
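        """Update fields from a dict, ignoring keys that are not config attributes.

        For boolean feature fields and the tri-state `paged_kv_cache`, the strings
        "enable"/"disable" are mapped to True/False; for plugin dtype fields,
        "disable" is mapped to None.
        """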
for name in config.keys():
if hasattr(self, name):
value_to_be_update = config[name]
if isinstance(getattr(self, name),
bool) or name == 'paged_kv_cache':
if value_to_be_update == "enable":
value_to_be_update = True
elif value_to_be_update == "disable":
value_to_be_update = False
elif value_to_be_update == "disable":
value_to_be_update = None
setattr(self, name, value_to_be_update)
@classmethod
def from_dict(cls, config: dict):
plugin_config = cls()
plugin_config.update_from_dict(config)
return plugin_config
@classmethod
def from_arguments(cls, args: argparse.Namespace):
return cls.from_dict(vars(args))
def to_dict(self):
config = asdict(self)
# Remove prefix "_" of the storage name
config = {key.lstrip('_'): value for key, value in config.items()}
return config
def to_legacy_setting(self):
'''Legacy setting means that all of the plugins and features are
        disabled; this is needed for the legacy `build.py` script, which will be
migrated to the centralized building script `tensorrt_llm/commands/build.py`.
After the migration is done, this function may or may not be deleted.
'''
for field in fields(self):
# Remove prefix "_" of the storage name
field_name = field.name.lstrip('_')
if field_name == 'dtype':
continue
if field.type in (str, Optional[str]):
setattr(self, field_name, None)
elif field.type == bool or field_name == 'paged_kv_cache':
setattr(self, field_name, False)
@property
def context_fmha_type(self):
if self.bert_context_fmha_fp32_acc:
return ContextFMHAType.enabled_with_fp32_acc
elif self.context_fmha:
return ContextFMHAType.enabled
else:
return ContextFMHAType.disabled
def is_context_fmha_enabled(self):
return self.context_fmha_type != ContextFMHAType.disabled
@context_fmha_type.setter
def context_fmha_type(self, value):
if value == ContextFMHAType.disabled:
self.context_fmha = False
self.bert_context_fmha_fp32_acc = False
else:
self.context_fmha = True
if value == ContextFMHAType.enabled:
self.bert_context_fmha_fp32_acc = False
elif value == ContextFMHAType.enabled_with_fp32_acc:
self.bert_context_fmha_fp32_acc = True
def set_smooth_quant_plugins(self, dtype: str = "auto"):
self.smooth_quant_gemm_plugin = dtype
self.rmsnorm_quantization_plugin = dtype
self.layernorm_quantization_plugin = dtype
self.quantize_per_token_plugin = True
self.quantize_tensor_plugin = True
return self
def set_qserve_plugins(self, dtype: str = "auto"):
self.qserve_gemm_plugin = dtype
self.rmsnorm_quantization_plugin = dtype
self.quantize_per_token_plugin = True
return self
def set_fp8_rowwise_quant_plugins(self, dtype: str = "auto"):
self.fp8_rowwise_gemm_plugin = dtype
self.rmsnorm_quantization_plugin = dtype
# self.layernorm_quantization_plugin = dtype
self.quantize_per_token_plugin = True
self.quantize_tensor_plugin = True
return self
def set_context_fmha(self, context_fmha_type=ContextFMHAType.enabled):
assert type(context_fmha_type) == ContextFMHAType
self.context_fmha_type = context_fmha_type
return self
def enable_paged_kv_cache(self, tokens_per_block: int = 64):
self.paged_kv_cache = True
self.tokens_per_block = tokens_per_block
return self
def set_nccl_plugin(self, dtype: str = "auto"):
self.nccl_plugin = dtype
init_all_reduce_helper()
return self


# Only the plugin configs in this list are exposed as `trtllm-build` arguments;
# the others are enabled automatically when needed and do not need to be
# controlled by users.
cli_plugin_args = [
# Plugins
"bert_attention_plugin",
"gpt_attention_plugin",
"gemm_plugin",
"gemm_swiglu_plugin",
"fp8_rowwise_gemm_plugin",
"lora_plugin",
"moe_plugin",
"mamba_conv1d_plugin",
"nccl_plugin",
"low_latency_gemm_plugin",
"low_latency_gemm_swiglu_plugin",
# Features
"context_fmha",
"bert_context_fmha_fp32_acc",
"remove_input_padding",
"tokens_per_block",
"use_paged_context_fmha",
"use_fp8_context_fmha",
"multiple_profiles",
"paged_state",
"streamingllm",
"reduce_fusion",
"user_buffer",
"use_fused_mlp",
"pp_reduce_scatter",
]


def add_plugin_argument(parser: argparse.ArgumentParser):
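    """Register one `--<field_name>` command-line argument per entry in
    `cli_plugin_args` on the given parser, using each field's metadata help
    message. Disabled plugins and boolean features are exposed as the strings
    "disable"/"enable"."""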
plugin_config = PluginConfig()
for field in fields(plugin_config):
# Remove prefix "_" of the storage name
field_name = field.name.lstrip('_')
if field_name not in cli_plugin_args:
continue
if field.metadata and "help" in field.metadata:
help_message = field.metadata["help"]
else:
raise AttributeError(f"Please add help message for {field_name}.")
if field.type in (str, Optional[str]):
plugin_dtype_options = DEFAULT_PLUGIN_DTYPE_OPTIONS
if field_name in PLUGIN_DTYPE_OPTIONS_MAP:
plugin_dtype_options = PLUGIN_DTYPE_OPTIONS_MAP[field_name]
parser.add_argument(
"--" + field_name,
type=str,
default=field.default if field.default else "disable",
choices=[x if x else "disable" for x in plugin_dtype_options],
help=help_message)
elif field.type == bool:
parser.add_argument(
"--" + field_name,
type=str,
default="enable" if field.default else "disable",
choices=["enable", "disable"],
help=help_message)
else:
parser.add_argument("--" + field_name,
type=field.type,
default=field.default,
help=help_message)
return parser


class CustomAllReduceHelper:
"""
Globally visible class to help usage of custom_all_reduce plugin.
Provides the following utilities:
workspace: Tensor
When using CUSTOM or AUTO mode, a tensor containing pointers to memory
        visible to all GPUs. It holds POINTERS_PER_RANK pointers per TP rank
        (ping/pong data buffers, in/out barriers, and three Lamport buffers),
        followed by POINTERS_OF_COUNTER counter entries.
It must be initialized using IpcMemory class.
Usage:
- Set custom_all_reduce_helper.workspace with the required tensor.
Then, each instance of allreduce will reference that tensor automatically.
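
    Example (illustrative sketch; assumes `mapping` is the tensorrt_llm Mapping
    of the current rank):
        init_all_reduce_helper()
        current_all_reduce_helper().set_workspace_tensor(mapping)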
"""
POINTERS_PER_RANK = 7
POINTERS_OF_COUNTER = 2
def __init__(self) -> None:
self.workspace: Optional[Tensor] = None
def set_workspace_tensor(self,
mapping: Mapping,
num_profiles: Optional[int] = None):
from ..functional import Tensor
workspace_size = self.POINTERS_PER_RANK * mapping.tp_size + self.POINTERS_OF_COUNTER
dim_range = None
if num_profiles is not None:
dim_range = OrderedDict([('all_reduce_size',
[workspace_size] * num_profiles)])
self.workspace = Tensor(
name='all_reduce_workspace',
dtype=trt.int64,
shape=[workspace_size],
dim_range=dim_range,
)
@staticmethod
def max_workspace_size_auto(tp_size: int) -> int:
if tp_size <= 2:
return 16_000_000
return 8_000_000
@staticmethod
def allocate_workspace(mapping: Mapping,
size: int) -> Tuple[List[IpcMemory], "torch.tensor"]:
import torch
is_p2p_supported = can_access_peer(mapping)
ipc_buffers_ping = IpcMemory(mapping, size * mapping.tp_size,
is_p2p_supported)
ipc_buffers_pong = IpcMemory(mapping, size * mapping.tp_size,
is_p2p_supported)
ipc_barriers_in = IpcMemory(
mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size * 2,
is_p2p_supported)
ipc_barriers_out = IpcMemory(
mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size * 2,
is_p2p_supported)
lamport_buffers_0 = IpcMemory(mapping, size * mapping.tp_size,
is_p2p_supported)
lamport_buffers_1 = IpcMemory(mapping, size * mapping.tp_size,
is_p2p_supported)
lamport_buffers_2 = IpcMemory(mapping, size * mapping.tp_size,
is_p2p_supported)
rank = mapping.rank
tp_rank = mapping.tp_rank
if rank == tp_rank and is_p2p_supported:
lamport_initialize_all(
lamport_buffers_0.local_ptr,
lamport_buffers_1.local_ptr,
lamport_buffers_2.local_ptr,
size * mapping.tp_size,
)
buffers = [
ipc_buffers_ping, ipc_buffers_pong, ipc_barriers_in,
ipc_barriers_out, lamport_buffers_0, lamport_buffers_1,
lamport_buffers_2
]
return buffers, torch.tensor(
ipc_buffers_ping.serialize() + ipc_buffers_pong.serialize() +
ipc_barriers_in.serialize() + ipc_barriers_out.serialize() +
lamport_buffers_0.serialize() + lamport_buffers_1.serialize() +
lamport_buffers_2.serialize() + [0] + [0],
dtype=torch.int64,
device="cpu")


custom_all_reduce_helper = None


def init_all_reduce_helper():
    global custom_all_reduce_helper
    custom_all_reduce_helper = CustomAllReduceHelper()


def current_all_reduce_helper():
    global custom_all_reduce_helper
    assert custom_all_reduce_helper is not None, "You must call `init_all_reduce_helper` first"
    return custom_all_reduce_helper