[GH-ISSUE #744] Chapter 4 Notes on RNN training #5524
Originally created by @jasonlyik on GitHub (Mar 11, 2025).
Original GitHub issue: https://github.com/harvard-edge/cs249r_book/issues/744
Additions to Section 4.4, "Recurrent Neural Networks: Sequential Pattern Processing"
Not sure if this is too much detail for this book, but a big factor behind RNNs' disappearance and recent resurgence is the difficulty of parallelizing their training.
For a given sequence, the RNN's state and output at each step depend on the previous state, so a classic non-linear RNN must compute every output, and every token's gradient during backpropagation through time (BPTT), strictly in sequence, which limits parallelization. This can be partially mitigated by inflating the batch size, exploiting parallelism across the batch while each sequence remains sequential, but large batches are not an effective substitute during training. Ideally, we would want a single parallel tensor operation that computes the entire output sequence for a batch in one shot, so the whole result can be passed to the next (usually stateless) layer.
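To make the sequential dependency concrete, here is a minimal NumPy sketch of a tanh RNN forward pass (the names `rnn_forward`, `W_x`, `W_h` are illustrative, not from the book):

```python
import numpy as np

def rnn_forward(x, W_x, W_h, b):
    """Minimal tanh RNN forward pass. x has shape (batch, T, d_in).

    The loop over t is inherently sequential: h_t depends on h_{t-1}
    through a non-linearity, so there is no closed form that yields all
    T outputs at once. Only the batch dimension is trivially parallel,
    which is why inflating the batch size is the usual (weak) remedy.
    """
    batch, T, _ = x.shape
    d_hidden = W_h.shape[0]
    h = np.zeros((batch, d_hidden))
    outputs = []
    for t in range(T):
        # Each step must wait for the previous one to finish.
        h = np.tanh(x[:, t] @ W_x + h @ W_h + b)
        outputs.append(h)
    return np.stack(outputs, axis=1)  # (batch, T, d_hidden)
```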
Arguably, RNNs disappeared because the sequential portion of BPTT made their training untenable on GPUs; they were replaced by stateless sequence models (Transformers), which parallelize effectively across the sequence dimension.
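For contrast, a simplified single-head causal self-attention sketch (illustrative names, no multi-head machinery) shows why stateless layers parallelize: the whole sequence is produced by a few large matrix multiplies, with no step-to-step dependency:

```python
import numpy as np

def causal_attention(x, W_q, W_k, W_v):
    """Single-head causal self-attention over x of shape (batch, T, d).

    Unlike the RNN loop above, all T positions are produced by a few
    dense matrix products; nothing forces step t to wait for step t-1.
    """
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    # Causal mask: position t may only attend to positions <= t.
    T = scores.shape[-1]
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # (batch, T, d_v)
```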
Recently, RNNs have reappeared in state space model (SSM) and linear recurrent unit (LRU) style architectures, which remove the non-linearity from the recurrence itself. With a fully linear recurrence, the evolution of the state and output has a closed form and can therefore be computed efficiently in parallel as a convolution (via FFT) or a parallel (associative) scan. This has made training tractable again and restored competitive performance for stateful architectures.
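A minimal sketch of the parallel-scan idea for a diagonal linear recurrence h_t = a_t * h_{t-1} + b_t (illustrative names; production implementations typically use something like jax.lax.associative_scan or a custom kernel rather than Python recursion):

```python
import numpy as np

def combine(e1, e2):
    """Compose two affine maps h -> a*h + b (e1 applied first).

    Associativity of this operator is exactly what allows an
    O(log T)-depth parallel scan over the recurrence.
    """
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

def parallel_linear_scan(a, b):
    """Inclusive scan for h_t = a[t] * h_{t-1} + b[t], with h_0 = 0.

    a and b have shape (T, d). The two recursive halves are independent
    and could run concurrently, giving O(log T) sequential depth
    instead of the O(T) loop of a classic non-linear RNN.
    """
    T = len(a)
    if T == 1:
        return a.copy(), b.copy()
    mid = T // 2
    a_lo, b_lo = parallel_linear_scan(a[:mid], b[:mid])
    a_hi, b_hi = parallel_linear_scan(a[mid:], b[mid:])
    # Prepend the left half's total effect to every right-half prefix.
    a_hi, b_hi = combine((a_lo[-1], b_lo[-1]), (a_hi, b_hi))
    return np.concatenate([a_lo, a_hi]), np.concatenate([b_lo, b_hi])

# Quick check against the sequential recurrence:
T, d = 8, 4
a, b = np.random.rand(T, d), np.random.randn(T, d)
h, ref = np.zeros(d), []
for t in range(T):
    h = a[t] * h + b[t]
    ref.append(h)
assert np.allclose(parallel_linear_scan(a, b)[1], np.stack(ref))
```

The second return value is the full hidden-state trajectory: every h_t for the whole sequence, computed without a step-by-step loop over time.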
Note that non-linearities are still present in an MLP/GLU channel-mixing layer following each recurrent block; recent work suggests this amount of non-linearity is sufficient.
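For completeness, a sketch of such a position-wise GLU channel-mixing layer (illustrative names); because it acts on each timestep independently, it supplies the non-linearity without reintroducing any sequential dependency:

```python
import numpy as np

def glu_channel_mixing(h, W_gate, W_val, W_out):
    """Position-wise gated linear unit over h of shape (batch, T, d).

    All of the block's non-linearity lives here. It is applied to every
    timestep independently, so it parallelizes over the sequence just
    like any other stateless layer.
    """
    gate = 1.0 / (1.0 + np.exp(-(h @ W_gate)))  # sigmoid gate
    return (gate * (h @ W_val)) @ W_out
```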
@profvjreddi commented on GitHub (Mar 14, 2025):
Thanks @jasonlyik. Will make a pass at incorporating this soon.
@profvjreddi commented on GitHub (Mar 28, 2025):
@18jeffreyma could you see if this makes sense to incorporate?