# compute_epsilons.py
from datasets import load_dataset_builder
import numpy as np
import argparse
import utils
def _positive_int(value):
    """Parse *value* as an int and require it to be strictly positive.

    Intended as an argparse ``type=`` callable: raising
    ``argparse.ArgumentTypeError`` makes argparse print a usage error and
    exit instead of letting a zero/negative size fail later with a traceback.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}")
    return number


def main():
    """Print the epsilon returned by ``utils.compute_epsilons`` per noise multiplier.

    Command-line arguments:
        --dataset: dataset name passed to ``load_dataset_builder`` (required).
        --lang_pair: dataset configuration name (default ``'de-en'``).
        --batch_size: per-device batch size, positive int (required).
        --gradient_accumulation_steps: positive int, default 1.
        --noise_multiplier: single value to evaluate; when omitted a fixed
            grid of values is evaluated instead.
        --device_count: number of devices, positive int (required).
        --epochs: number of epochs (required).

    The size of the ``train`` split is read from the dataset builder's
    metadata and rounded up to a whole number of batches before being handed
    to ``utils.compute_epsilons``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--dataset", type=str, required=True)
    arg_parser.add_argument("--lang_pair", type=str, default='de-en')
    # Sizes are validated at parse time: a zero total batch size would
    # otherwise raise ZeroDivisionError only after the dataset lookup.
    arg_parser.add_argument("--batch_size", type=_positive_int, required=True)
    arg_parser.add_argument(
        "--gradient_accumulation_steps",
        type=_positive_int,
        default=1,
        help="Scale up the batch size")
    arg_parser.add_argument("--noise_multiplier", type=float, default=None)
    arg_parser.add_argument("--device_count", type=_positive_int, required=True)
    arg_parser.add_argument("--epochs", type=int, required=True)
    args = arg_parser.parse_args()

    # Set values
    total_batch_size = args.batch_size * args.device_count
    accumulation_batch_size = total_batch_size * args.gradient_accumulation_steps
    epochs = args.epochs

    ds_builder = load_dataset_builder(args.dataset, args.lang_pair)
    len_train_dataset = ds_builder.info.splits['train'].num_examples

    # Round the dataset size up to the next multiple of the total batch size
    # (ceiling division; total_batch_size is guaranteed positive here).
    num_batch = -(-len_train_dataset // total_batch_size)
    actual_compute_len_train = num_batch * total_batch_size

    if args.noise_multiplier is None:
        # No explicit value given: sweep a grid that is fine-grained for
        # small multipliers and coarser for large ones.
        noise_multipliers = np.concatenate(
            (np.arange(0.0, 1.0, 0.01),
             np.arange(1.0, 5.0, 0.2),
             np.arange(5.0, 100, 5.0),
             np.array([128, 256])
             )
        )
    else:
        noise_multipliers = [args.noise_multiplier]

    print("Total number of training examples:", actual_compute_len_train)
    print("Number of devices:", args.device_count)
    print("Total batch size:", total_batch_size)
    print("Gradient accumulation steps:", args.gradient_accumulation_steps)
    print("Accumulation batch size:", accumulation_batch_size)
    print("Sampling rate:", accumulation_batch_size / actual_compute_len_train)
    print("Epochs:", epochs)

    for noise_multiplier in noise_multipliers:
        print("Noise multiplier:", noise_multiplier)
        epsilon = utils.compute_epsilons(
            actual_compute_len_train,
            accumulation_batch_size,
            noise_multiplier,
            epochs
        )
        print("Resulting epsilon:", epsilon)
        print("\n")
# Run the command-line entry point only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()