Vanilla RNN cell forward
Background
The vanilla RNN cell is one timestep of the recurrent computation: blend the current input with the previous hidden state and squash with tanh. By itself it's trivial, but composed across hundreds of timesteps it produces the vanishing-gradient pathology that LSTMs and GRUs were built to mitigate — and that transformers sidestepped entirely with attention. It remains a core interview question and a useful reference point for understanding what attention solved.
Problem statement
Implement rnn_cell_forward(x, h_prev, W_x, W_h, b), a single RNN step:
h_next = tanh(x @ W_x + h_prev @ W_h + b)
Two matmuls (input-to-hidden and hidden-to-hidden), a bias add, and a tanh.
Input
- x, shape (B, input_size): the input at this timestep.
- h_prev, shape (B, hidden_size): the hidden state from the previous step.
- W_x, shape (input_size, hidden_size): input-to-hidden weights.
- W_h, shape (hidden_size, hidden_size): hidden-to-hidden weights.
- b, shape (hidden_size,): the bias.
Output
Returns h_next of shape (B, hidden_size) — the new hidden state.
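A minimal sketch of the step described above, assuming the inputs are NumPy arrays (the problem names no library, so NumPy is an assumption here). The sample shapes and random values are illustrative only.

```python
import numpy as np

def rnn_cell_forward(x, h_prev, W_x, W_h, b):
    """One vanilla RNN step: h_next = tanh(x @ W_x + h_prev @ W_h + b)."""
    # (B, input_size) @ (input_size, hidden_size) -> (B, hidden_size)
    input_term = x @ W_x
    # (B, hidden_size) @ (hidden_size, hidden_size) -> (B, hidden_size)
    recurrent_term = h_prev @ W_h
    # b has shape (hidden_size,) and broadcasts across the batch rows
    return np.tanh(input_term + recurrent_term + b)

# Illustrative shapes: B=2, input_size=3, hidden_size=4 (arbitrary values).
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
h_prev = rng.standard_normal((2, 4))
W_x = rng.standard_normal((3, 4))
W_h = rng.standard_normal((4, 4))
b = rng.standard_normal(4)

h_next = rnn_cell_forward(x, h_prev, W_x, W_h, b)
print(h_next.shape)  # (2, 4)
```

Keeping the two matmuls as separate named terms makes the shape bookkeeping easy to check before the bias add and the tanh.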
Examples
Example 1 — the bias broadcasts across the batch
Input: x = 0 (B, I), h_prev = 0 (B, H), W_x = 0, W_h = 0, b = [1, -1, 0.5, -0.5]
Output: every row = tanh(b) ≈ [0.762, -0.762, 0.462, -0.462]
Explanation: with zero input and zero previous state the pre-activation is just b, so each of the B batch rows is tanh(b) — the (H,) bias broadcasts across the batch dimension.
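Example 1 can be reproduced directly with NumPy (an assumed library). The batch size B = 3 and input size I = 2 below are arbitrary choices; only H = 4 is fixed by the length of b.

```python
import numpy as np

B, I, H = 3, 2, 4  # B and I are arbitrary; H matches len(b)
x = np.zeros((B, I))
h_prev = np.zeros((B, H))
W_x = np.zeros((I, H))
W_h = np.zeros((H, H))
b = np.array([1.0, -1.0, 0.5, -0.5])

# Both matmuls are zero, so the pre-activation is b broadcast to (B, H).
h_next = np.tanh(x @ W_x + h_prev @ W_h + b)

# Every row is approximately [0.762, -0.762, 0.462, -0.462].
print(np.round(h_next, 3))
```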
Example 2 — tanh bounds the output
Input: large-magnitude x, h_prev, W_x, W_h (e.g. scaled by 100 / 10)
Output: every entry of h_next lies in (-1, 1)
Explanation: however large the pre-activation, tanh squashes it into (-1, 1). That bounded range is what stops the hidden state from exploding as it's fed back through W_h each step, but it also makes the gradient nearly zero for large-magnitude pre-activations, the saturation behind vanishing gradients at depth.
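To make the saturation concrete, the sketch below evaluates tanh and its derivative, 1 - tanh(z)^2, at a few sample pre-activations. The symbol z and the sample values are illustrative assumptions, not part of the problem.

```python
import numpy as np

z = np.array([0.0, 1.0, 3.0, 10.0])  # sample pre-activations (illustrative)
h = np.tanh(z)
grad = 1.0 - h**2                     # derivative of tanh with respect to z

for zi, hi, gi in zip(z, h, grad):
    print(f"z={zi:5.1f}  tanh(z)={hi:.6f}  dtanh/dz={gi:.2e}")
```

The derivative is 1 at z = 0 and shrinks rapidly as |z| grows. In floating point, a sufficiently large |z| rounds tanh(z) to exactly ±1, so a numerical check of the bound may need <= rather than <.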
Constraints
- Compute tanh(x @ W_x + h_prev @ W_h + b); the two matmuls each produce (B, H) and add elementwise, with b broadcasting across the batch.
- Output is always bounded in (-1, 1) regardless of input magnitude.
- Identity-recurrence sanity check: with W_x=0, W_h=I, b=0 and small h_prev, the output ≈ h_prev (since tanh(small) ≈ small).
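The identity-recurrence check above can be sketched as follows, assuming NumPy and taking 0.01 as an arbitrary stand-in for a "small" hidden state.

```python
import numpy as np

B, I, H = 2, 3, 4                  # arbitrary illustrative sizes
x = np.ones((B, I))                # has no effect because W_x is all zeros
h_prev = np.full((B, H), 0.01)     # "small" hidden state (assumed value)
W_x = np.zeros((I, H))
W_h = np.eye(H)                    # identity recurrence
b = np.zeros(H)

h_next = np.tanh(x @ W_x + h_prev @ W_h + b)

# tanh(v) is roughly v - v**3 / 3 for small v, so the gap is tiny.
max_diff = np.max(np.abs(h_next - h_prev))
print(max_diff)
```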
Notes
- Why tanh, not ReLU. The hidden state is multiplied by W_h and re-fed every step; an unbounded ReLU can explode through repeated multiplication, so the bounded range of tanh (or sigmoid) is the classic choice.
- Historical arc. Vanilla RNNs gave way to LSTMs (introduced in 1997, widely adopted by ~2014) and then to transformers (introduced in 2017, dominant from ~2018); this cell is the baseline the attention mechanism ultimately replaced.
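The tanh-versus-ReLU note can be illustrated with a contrived recurrence: no input, no bias, and W_h = 1.5 * I so that each step amplifies the state. These weights, the starting value, and the step count are assumptions chosen to make the effect visible, not values from the problem.

```python
import numpy as np

H, steps = 4, 50
W_h = 1.5 * np.eye(H)            # contrived weights with gain > 1
h_tanh = np.full((1, H), 0.1)    # same starting state for both variants
h_relu = np.full((1, H), 0.1)

for _ in range(steps):
    h_tanh = np.tanh(h_tanh @ W_h)          # bounded activation
    h_relu = np.maximum(0.0, h_relu @ W_h)  # unbounded activation

print(np.max(np.abs(h_tanh)))  # stays within [-1, 1]
print(np.max(h_relu))          # grows like 0.1 * 1.5**steps
```

With weights whose gain is below 1, the ReLU state would shrink instead of growing, so the blow-up depends on W_h; tanh bounds the state regardless.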
This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.
- Output shape is (B, hidden_size)
- Diagnostic: matches the explicit formula
- Output values are bounded in (-1, 1) (tanh saturation)
- Bias broadcasts correctly across the batch dimension
- Identity-recurrence behaviour: with W_x=0, W_h=I, b=0 and small h_prev, h_next stays approximately equal to h_prev