-
Notifications
You must be signed in to change notification settings - Fork 286
/
zipformer.py
2462 lines (2099 loc) · 91.5 KB
/
zipformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
import random
import warnings
from typing import List, Optional, Tuple, Union
import torch
from encoder_interface import EncoderInterface
from scaling import (
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
)
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
)
from scaling import (
ActivationDropoutAndLinear,
Balancer,
BiasNorm,
ChunkCausalDepthwiseConv1d,
Dropout2,
FloatLike,
ScheduledFloat,
Whiten,
convert_num_channels,
limit_param_value,
penalize_abs_values_gt,
softmax,
)
from torch import Tensor, nn
class Zipformer2(EncoderInterface):
    """
    Args:

    Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
    as downsampling_factor if they are single ints or one-element tuples.  The length of
    downsampling_factor defines the number of stacks.

        output_downsampling_factor (int): how much to downsample at the output.  Note:
            we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
            Only 2 is currently supported: forward() and streaming_forward() assert it.
        downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
           Note: this is in addition to the downsampling factor of 2 that is applied in
           the frontend before this encoder.
        encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
           encoder stack.
        num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
        encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
            the encoder stacks for purposes of per-frame dropout (recommend 256 for
            now).  Must not exceed the corresponding encoder_dim.
        query_head_dim (int or Tuple[int]): dimension of query and key per attention
           head: per stack, if a tuple..
        pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
           attention head
        value_head_dim (int or Tuple[int]): dimension of value in each attention head
        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
              Must be at least 4.
        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
        cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module

        pos_dim (int): the dimension of each positional-encoding vector prior to projection,
            e.g. 128.

        dropout (FloatLike or None): dropout rate.  If None, the schedule
           ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) is used.
        warmup_batches (float): number of batches to warm up over; this controls
          dropout of encoder layers.
        causal (bool): if True, support chunkwise causal convolution.  This should
          not hurt WER as no modeling power is lost, but the convolution modules will be
          slightly slower and use more memory.  Enables use of the chunk_size and
          left_context_chunks options in forward(), which simulates streaming
          decoding.
        chunk_size: (sequence of int): only set this to other than (-1,) if causal;
           in training the chunk size will be randomly chosen from this sequence.
           -1 means no chunking.  When scripting or tracing it must contain
           exactly one value.
        left_context_frames: (sequence of int): determines the number of left-
          context chunks for causal training; will be rounded down to a whole number
          of chunks (at least 1); -1 means unlimited left context.  For every stack i,
          chunk_size * left_context_chunks must be at least
          (cnn_module_kernel[i] // 2) * downsampling_factor[i]; an assertion fails if
          this is violated.  Its first entry is also the left context used by
          streaming_forward() and get_init_states().
    """

    def __init__(
        self,
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (2, 4),
        encoder_dim: Union[int, Tuple[int]] = 384,
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
        # Immutable defaults: a list default would be one object shared by
        # every instance constructed without these arguments.
        chunk_size: Tuple[int] = (-1,),
        left_context_frames: Tuple[int] = (-1,),
    ) -> None:
        super(Zipformer2, self).__init__()

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
            return x

        self.output_downsampling_factor = output_downsampling_factor  # int
        self.downsampling_factor = downsampling_factor  # tuple
        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
            encoder_unmasked_dim
        )  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
        pos_head_dim = _to_tuple(pos_head_dim)
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

        self.causal = causal
        self.chunk_size = chunk_size
        self.left_context_frames = left_context_frames

        for u, d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d, (u, d)

        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
        encoders = []

        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):
            encoder_layer = Zipformer2EncoderLayer(
                embed_dim=encoder_dim[i],
                pos_dim=pos_dim,
                num_heads=num_heads[i],
                query_head_dim=query_head_dim[i],
                pos_head_dim=pos_head_dim[i],
                value_head_dim=value_head_dim[i],
                feedforward_dim=feedforward_dim[i],
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
                causal=causal,
            )

            # For the segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something.  Then we start to warm up the other encoders.
            encoder = Zipformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )

            if downsampling_factor[i] != 1:
                encoder = DownsampledZipformer2Encoder(
                    encoder,
                    dim=encoder_dim[i],
                    downsample=downsampling_factor[i],
                    dropout=dropout,
                    causal=causal,
                )

            encoders.append(encoder)

        self.encoders = nn.ModuleList(encoders)

        self.downsample_output = SimpleDownsample(
            max(encoder_dim),
            downsample=output_downsampling_factor,
            dropout=dropout,
            causal=causal,
        )

    def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
        """
        In eval mode, returns [1.0] * num_encoders; in training mode, returns one
        randomized feature mask per encoder stack.

        Each mask has shape (1, batch_size, encoder_dim[i]), so one decision is made
        per sequence and broadcast over all of its frames.  Channels below
        encoder_unmasked_dim[i] are never masked.  Of the remaining channels, the
        lower half is zeroed for a sequence when a first random draw fails
        (probability 0.125).  The upper half is zeroed when that draw or a second
        independent one fails (probability about 0.23).  On those sequences we are in
        effect using a smaller encoder dim.

        The random decisions are drawn once, here, so that the masks of all stacks
        'agree' for each sequence all the way up the encoder stack.

        Args:
          x: the embeddings (needed for the shape and dtype and device), of shape
             (num_frames, batch_size, encoder_dim[0])
        Returns:
          A list with one entry per encoder stack: 1.0 in eval mode, otherwise a
          Tensor of shape (1, batch_size, encoder_dim[i]) and dtype x.dtype.
        """
        num_encoders = len(self.encoder_dim)
        if not self.training:
            return [1.0] * num_encoders

        (num_frames0, batch_size, _encoder_dims0) = x.shape

        assert self.encoder_dim[0] == _encoder_dims0, (
            self.encoder_dim[0],
            _encoder_dims0,
        )

        feature_mask_dropout_prob = 0.125

        # mask1 shape: (1, batch_size, 1); 1.0 means "keep".
        mask1 = (
            torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
        ).to(x.dtype)

        # mask2 has additional sequences masked, about twice the number.
        # logical_and returns bool, so cast back to x.dtype so that the cat below
        # joins two tensors of the same dtype.
        mask2 = torch.logical_and(
            mask1,
            (
                torch.rand(1, batch_size, 1, device=x.device)
                > feature_mask_dropout_prob
            ).to(x.dtype),
        ).to(x.dtype)

        # dim: (1, batch_size, 2)
        mask = torch.cat((mask1, mask2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
            channels = self.encoder_dim[i]
            feature_mask = torch.ones(
                1, batch_size, channels, dtype=x.dtype, device=x.device
            )
            u1 = self.encoder_unmasked_dim[i]
            u2 = u1 + (channels - u1) // 2

            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
            feature_mask[:, :, u2:] *= mask[..., 1:2]

            feature_masks.append(feature_mask)

        return feature_masks

    def get_chunk_info(self) -> Tuple[int, int]:
        """
        Returns (chunk_size, left_context_chunks).

        Both are -1 when the model is not causal.  Otherwise chunk_size is drawn at
        random from self.chunk_size and, unless it is -1, left_context_frames is drawn
        from self.left_context_frames and converted to a whole number of chunks
        (at least 1; -1 stays -1, meaning unlimited).  When scripting or tracing,
        both sequences must contain exactly one value.
        """
        if not self.causal:
            return -1, -1

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert len(self.chunk_size) == 1, self.chunk_size
            chunk_size = self.chunk_size[0]
        else:
            chunk_size = random.choice(self.chunk_size)

        if chunk_size == -1:
            left_context_chunks = -1
        else:
            if torch.jit.is_scripting() or torch.jit.is_tracing():
                assert len(self.left_context_frames) == 1, self.left_context_frames
                left_context_frames = self.left_context_frames[0]
            else:
                left_context_frames = random.choice(self.left_context_frames)
            # Note: in Python, -1 // n == -1 for n > 0
            left_context_chunks = left_context_frames // chunk_size
            if left_context_chunks == 0:
                left_context_chunks = 1

        return chunk_size, left_context_chunks

    def forward(
        self,
        x: Tensor,
        x_lens: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
        Returns:
          Return a tuple containing 2 tensors:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
        """
        outputs = []
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            feature_masks = [1.0] * len(self.encoder_dim)
        else:
            feature_masks = self.get_feature_masks(x)

        chunk_size, left_context_chunks = self.get_chunk_info()

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Not support exporting a model for simulating streaming decoding
            attn_mask = None
        else:
            attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])

            x = module(
                x,
                chunk_size=chunk_size,
                feature_mask=feature_masks[i],
                src_key_padding_mask=(
                    None
                    if src_key_padding_mask is None
                    else src_key_padding_mask[..., ::ds]
                ),
                attn_mask=attn_mask,
            )
            outputs.append(x)

        # if the last output has the largest dimension, x will be unchanged,
        # it will be the same as outputs[-1].  Otherwise it will be concatenated
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = self._get_full_dim_output(outputs)
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2

        return x, lengths

    def _get_attn_mask(
        self, x: Tensor, chunk_size: int, left_context_chunks: int
    ) -> Optional[Tensor]:
        """
        Return None if chunk_size <= 0, else return a boolean attention mask of shape
        (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
        means a masked position.  A frame may attend to frames in its own chunk and
        in the left_context_chunks chunks before it, but never to later chunks.

        Args:
          x: the input embeddings, of shape (seq_len, batch_size, embed_dim);
             only seq_len and the device are used.
          chunk_size: chunk size in frames; must be divisible by every entry of
             self.downsampling_factor.
          left_context_chunks: number of chunks of left context; a negative value
             means unlimited.  Otherwise, for every stack i,
             chunk_size * left_context_chunks must be at least
             (cnn_module_kernel[i] // 2) * downsampling_factor[i].
        """
        if chunk_size <= 0:
            return None
        assert all(chunk_size % d == 0 for d in self.downsampling_factor), (
            chunk_size,
            self.downsampling_factor,
        )
        if left_context_chunks >= 0:
            num_encoders = len(self.encoder_dim)
            assert all(
                chunk_size * left_context_chunks
                >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
                for i in range(num_encoders)
            ), (
                chunk_size,
                left_context_chunks,
                self.cnn_module_kernel,
                self.downsampling_factor,
            )
        else:
            left_context_chunks = 1000000

        seq_len = x.shape[0]

        # t is frame index, shape (seq_len,)
        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
        # c is chunk index for each frame, shape (seq_len,)
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            c = t // chunk_size
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                c = t // chunk_size
        src_c = c
        tgt_c = c.unsqueeze(-1)

        # Mask future chunks and chunks further back than the allowed left context.
        attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
        if __name__ == "__main__":
            logging.info(f"attn_mask = {attn_mask}")
        return attn_mask

    def _get_full_dim_output(self, outputs: List[Tensor]):
        """Build a max(encoder_dim)-channel output from the per-stack outputs.

        Channels are taken from outputs[-1] first; each wider earlier stack then
        contributes only the channels that no later stack has.
        """
        num_encoders = len(self.encoder_dim)
        assert len(outputs) == num_encoders
        output_dim = max(self.encoder_dim)
        output_pieces = [outputs[-1]]
        cur_dim = self.encoder_dim[-1]
        for i in range(num_encoders - 2, -1, -1):
            d = self.encoder_dim[i]
            if d > cur_dim:
                this_output = outputs[i]
                output_pieces.append(this_output[..., cur_dim:d])
                cur_dim = d
        assert cur_dim == output_dim
        return torch.cat(output_pieces, dim=-1)

    def streaming_forward(
        self,
        x: Tensor,
        x_lens: Tensor,
        states: List[Tensor],
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          states: list of cached tensors of all encoder layers. For layer-i,
            states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
            cached_conv1, cached_conv2).
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position.  Required (it is sliced for every stack, so None
            is not accepted).
        Returns:
          Return a tuple containing 3 items:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
            - updated states, in the same layout as `states`.
        """
        outputs = []
        new_states = []
        layer_offset = 0

        for i, module in enumerate(self.encoders):
            num_layers = module.num_layers
            ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])

            x, new_layer_states = module.streaming_forward(
                x,
                states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
                left_context_len=self.left_context_frames[0] // ds,
                src_key_padding_mask=src_key_padding_mask[..., ::ds],
            )
            layer_offset += num_layers
            outputs.append(x)
            new_states += new_layer_states

        # if the last output has the largest dimension, x will be unchanged,
        # it will be the same as outputs[-1].  Otherwise it will be concatenated
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = self._get_full_dim_output(outputs)
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2

        return x, lengths, new_states

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> List[Tensor]:
        """Get initial (all-zero) states.

        A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
        is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
        Attention caches hold left_context_frames[0] // downsampling_factor frames;
        convolution caches hold cnn_module_kernel // 2 frames.
        """
        states = []
        for i, module in enumerate(self.encoders):
            num_layers = module.num_layers
            embed_dim = self.encoder_dim[i]
            ds = self.downsampling_factor[i]
            num_heads = self.num_heads[i]
            key_dim = self.query_head_dim[i] * num_heads
            value_dim = self.value_head_dim[i] * num_heads
            downsample_left = self.left_context_frames[0] // ds
            nonlin_attn_head_dim = 3 * embed_dim // 4
            conv_left_pad = self.cnn_module_kernel[i] // 2
            for _ in range(num_layers):
                # Allocate directly on the target device instead of creating on
                # CPU and copying over with .to(device).
                cached_key = torch.zeros(
                    downsample_left, batch_size, key_dim, device=device
                )
                cached_nonlin_attn = torch.zeros(
                    1, batch_size, downsample_left, nonlin_attn_head_dim, device=device
                )
                cached_val1 = torch.zeros(
                    downsample_left, batch_size, value_dim, device=device
                )
                cached_val2 = torch.zeros(
                    downsample_left, batch_size, value_dim, device=device
                )
                cached_conv1 = torch.zeros(
                    batch_size, embed_dim, conv_left_pad, device=device
                )
                cached_conv2 = torch.zeros(
                    batch_size, embed_dim, conv_left_pad, device=device
                )
                states += [
                    cached_key,
                    cached_nonlin_attn,
                    cached_val1,
                    cached_val2,
                    cached_conv1,
                    cached_conv2,
                ]

        return states
def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
    """Return a whitening-limit schedule that goes from ``x`` to ``ratio * x``.

    The breakpoints are ``(0.0, x)`` and ``(20000.0, ratio * x)``, and ``x`` is also
    the default value.  How values between the breakpoints are computed is
    defined by ``ScheduledFloat``.
    """
    first_point = (0.0, x)
    last_point = (20000.0, ratio * x)
    return ScheduledFloat(first_point, last_point, default=x)
def _balancer_schedule(min_prob: float):
    """Return a balancer-probability schedule from 0.4 down to ``min_prob``.

    The breakpoints are ``(0.0, 0.4)`` and ``(8000.0, min_prob)``; no explicit
    default is passed to ``ScheduledFloat``.
    """
    starting_point = (0.0, 0.4)
    ending_point = (8000.0, min_prob)
    return ScheduledFloat(starting_point, ending_point)
class Zipformer2EncoderLayer(nn.Module):
    """
    One Zipformer2 encoder layer.  It computes one set of relative-position
    attention weights and shares them between a non-linear attention module and
    two self-attention modules.  It also contains two convolution modules and three
    feedforward modules, with a bypass in the middle and at the end of the layer.

    Args:
        embed_dim: the number of expected features in the input (required).
        pos_dim: dimension of the positional-encoding input of the attention-weight
            module (required).
        num_heads: the number of attention heads (required).
        query_head_dim: dimension of query and key per attention head (required).
        pos_head_dim: dimension of the positional projection per attention head
            (required).
        value_head_dim: dimension of the value per attention head (required).
        feedforward_dim: base hidden dimension of the feedforward modules (required);
            the three modules use 3/4, 1x and 5/4 of it.
        dropout: the dropout value passed to the feedforward modules (default=0.1).
        cnn_module_kernel (int): Kernel size of convolution module (default=31).
        causal: if True, the convolution modules are constructed in causal mode.
        attention_skip_rate, conv_skip_rate, const_attention_rate, ff2_skip_rate,
        ff3_skip_rate: training-time skip/regularization rates (ScheduledFloat by
            default).  Each is deep-copied, so instances do not share schedule objects.
        bypass_skip_rate: skip rate passed to the end-of-layer BypassModule.

    Examples::
        >>> encoder_layer = Zipformer2EncoderLayer(
        ...     embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=32,
        ...     pos_head_dim=4, value_head_dim=12, feedforward_dim=2048)
        >>> src = torch.rand(10, 32, 512)
        >>> # (1, 2*seq_len-1, pos_emb_dim); assumes pos_emb_dim == pos_dim
        >>> pos_emb = torch.rand(1, 19, 192)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
        causal: bool = False,
        attention_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        conv_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
        ),
        const_attention_rate: FloatLike = ScheduledFloat(
            (0.0, 0.25), (4000.0, 0.025), default=0
        ),
        ff2_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        ff3_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
        ),
        bypass_skip_rate: FloatLike = ScheduledFloat(
            (0.0, 0.5), (4000.0, 0.02), default=0
        ),
    ) -> None:
        super(Zipformer2EncoderLayer, self).__init__()
        self.embed_dim = embed_dim

        # self.bypass implements layer skipping as well as bypass; see its default values.
        self.bypass = BypassModule(
            embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
        )
        # bypass_mid is bypass used in the middle of the layer.
        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)

        # skip probability for dynamic modules (meaning: anything but feedforward).
        # The default arguments are shared objects, hence the deep copies.
        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
        # an additional skip probability that applies to ConvModule to stop it from
        # contributing too much early on.
        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)

        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
        # compared to its residual.
        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
        self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)

        self.const_attention_rate = copy.deepcopy(const_attention_rate)

        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
            embed_dim,
            pos_dim=pos_dim,
            num_heads=num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            dropout=0.0,
        )

        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)

        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)

        self.feed_forward1 = FeedforwardModule(
            embed_dim, (feedforward_dim * 3) // 4, dropout
        )

        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)

        self.feed_forward3 = FeedforwardModule(
            embed_dim, (feedforward_dim * 5) // 4, dropout
        )

        self.nonlin_attention = NonlinAttention(
            embed_dim, hidden_channels=3 * embed_dim // 4
        )

        self.conv_module1 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        self.conv_module2 = ConvolutionModule(
            embed_dim, cnn_module_kernel, causal=causal
        )

        # TODO: remove it
        # (Not used in forward here; kept so existing checkpoints still load.)
        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

        self.norm = BiasNorm(embed_dim)

        self.balancer1 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.2,
            max_abs=4.0,
        )

        # balancer for output of NonlinAttentionModule
        self.balancer_na = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
            prob=0.05,  # out of concern for memory usage
        )

        # balancer for output of feedforward2, prevent it from staying too
        # small.  give this a very small probability, even at the start of
        # training, it's to fix a rare problem and it's OK to fix it slowly.
        self.balancer_ff2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
            max_abs=2.0,
            prob=0.05,
        )

        self.balancer_ff3 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.3,
            max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
            max_abs=4.0,
            prob=0.05,
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(4.0, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

        self.balancer2 = Balancer(
            embed_dim,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            min_abs=0.1,
            max_abs=4.0,
        )

    def get_sequence_dropout_mask(
        self, x: Tensor, dropout_rate: float
    ) -> Optional[Tensor]:
        """Return a per-sequence keep mask, or None when no dropout applies.

        Returns None if dropout_rate is 0, in eval mode, or when scripting or
        tracing.  Otherwise it returns a (batch_size, 1) tensor of 0/1 values in
        x.dtype, where each sequence is kept with probability 1 - dropout_rate.
        x is expected to have batch_size as dim 1.
        """
        if (
            dropout_rate == 0.0
            or not self.training
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        ):
            return None
        batch_size = x.shape[1]
        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
        return mask

    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
        """
        Apply sequence-level dropout to x.
        x shape: (seq_len, batch_size, embed_dim)
        Whole sequences are zeroed; x is returned unchanged when no dropout applies.
        """
        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
        if dropout_mask is None:
            return x
        else:
            return x * dropout_mask

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
            Pass the input through the encoder layer.
            Args:
                src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
             pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
             chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
           attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                  interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
        src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
                 masked position.  May be None.

            Returns:
               A tensor which has the same shape as src
        """
        src_orig = src

        # dropout rate for non-feedforward submodules
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            attention_skip_rate = 0.0
        else:
            attention_skip_rate = (
                float(self.attention_skip_rate) if self.training else 0.0
            )

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights = self.self_attn_weights(
            src,
            pos_emb=pos_emb,
            attn_mask=attn_mask,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        # One mask shared by the non-linear attention and both self-attention
        # modules, so they are skipped together for a given sequence.
        self_attn_dropout_mask = self.get_sequence_dropout_mask(
            src, attention_skip_rate
        )

        # Only the first head's weights are passed to the non-linear attention.
        selected_attn_weights = attn_weights[0:1]
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif self.training and random.random() < float(self.const_attention_rate):
            # Make attention weights constant.  The intention is to
            # encourage these modules to do something similar to an
            # averaging-over-time operation.
            # Every position with nonzero weight gets weight 1/(count of such
            # positions in its row).
            selected_attn_weights = (selected_attn_weights > 0.0).to(
                selected_attn_weights.dtype
            )
            selected_attn_weights = selected_attn_weights * (
                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
            )

        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))

        src = src + (
            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
        )

        self_attn = self.self_attn1(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        # The same skip rate is used for both convolution modules.
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module1(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff2_skip_rate = 0.0
        else:
            ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
        )

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn = self.self_attn2(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        src = src + self.sequence_dropout(
            self.conv_module2(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff3_skip_rate = 0.0
        else:
            ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
        )

        src = self.balancer1(src)
        src = self.norm(src)

        src = self.bypass(src_orig, src)

        src = self.balancer2(src)
        src = self.whiten(src)

        return src

    def streaming_forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        cached_key: Tensor,
        cached_nonlin_attn: Tensor,
        cached_val1: Tensor,
        cached_val2: Tensor,
        cached_conv1: Tensor,
        cached_conv2: Tensor,
        left_context_len: int,
        src_key_padding_mask: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Pass the input through the encoder layer in streaming forward mode.

        Unlike forward(), this path applies no skip rates, balancers or whitening.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
              (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
            cached_key: cached attention key tensor of left context,
              of shape (left_context_len, batch_size, key_dim)
            cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
              (num_heads, batch_size, left_context_len, head_dim)
            cached_val1: cached left context for the first attention module,
              of shape (left_context_len, batch_size, value_dim)
            cached_val2: cached left context for the second attention module,
              of shape (left_context_len, batch_size, value_dim)
            cached_conv1: cached left context for the first convolution module,
              of shape (batch_size, channels, left_pad)
            cached_conv2: cached left context for the second convolution module,
              of shape (batch_size, channels, left_pad)
            left_context_len: number of left context frames.
            src_key_padding_mask: the mask for padding, of shape
              (batch_size, left_context_len + seq_len); True means masked position.
              Required: it is sliced unconditionally, so None is not accepted.

        Returns:
            - x, with the same shape as src
            - updated cached_key
            - updated cached_nonlin_attn
            - updated cached_val1
            - updated cached_val2
            - updated cached_conv1
            - updated cached_conv2
        """
        src_orig = src

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
            src,
            pos_emb=pos_emb,
            cached_key=cached_key,
            left_context_len=left_context_len,
            key_padding_mask=src_key_padding_mask,
        )

        src = src + self.feed_forward1(src)

        na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
            src,
            attn_weights[0:1],
            cached_x=cached_nonlin_attn,
            left_context_len=left_context_len,
        )
        src = src + na

        self_attn, cached_val1 = self.self_attn1.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val1,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        # The convolution modules only see the current frames' padding mask.
        src_conv, cached_conv1 = self.conv_module1.streaming_forward(
            src,
            cache=cached_conv1,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward2(src)

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn, cached_val2 = self.self_attn2.streaming_forward(
            src,
            attn_weights=attn_weights,
            cached_val=cached_val2,
            left_context_len=left_context_len,
        )
        src = src + self_attn

        src_conv, cached_conv2 = self.conv_module2.streaming_forward(
            src,
            cache=cached_conv2,
            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
        )
        src = src + src_conv

        src = src + self.feed_forward3(src)

        src = self.norm(src)

        src = self.bypass(src_orig, src)

        return (
            src,
            cached_key,
            cached_nonlin_attn,
            cached_val1,
            cached_val2,
            cached_conv1,
            cached_conv2,
        )
class Zipformer2Encoder(nn.Module):
r"""Zipformer2Encoder is a stack of N encoder layers