Vanilla RNN cell forward

Difficulty: Medium

Background

The vanilla RNN cell is one timestep of the recurrent computation: blend the current input with the previous hidden state and squash with tanh. By itself it's trivial, but composed across hundreds of timesteps it produces the vanishing-gradient pathology that LSTMs and GRUs were built to mitigate — and that transformers sidestepped entirely with attention. It remains a core interview question and a useful reference point for understanding what attention solved.

Problem statement

Implement rnn_cell_forward(x, h_prev, W_x, W_h, b), a single RNN step:

$$h_t = \tanh\!\big(x_t\,W_x + h_{t-1}\,W_h + b\big)$$

Two matmuls (input-to-hidden and hidden-to-hidden), a bias add, and a tanh.

Input

  • x (B, input_size): the input at this timestep.
  • h_prev (B, hidden_size): the hidden state from the previous step.
  • W_x (input_size, hidden_size): input-to-hidden weights.
  • W_h (hidden_size, hidden_size): hidden-to-hidden weights.
  • b (hidden_size,): the bias.

Output

Returns h_next of shape (B, hidden_size) — the new hidden state.
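One way to write the step described above is sketched below with NumPy. This is a minimal sketch, not the reference solution: the batch and layer sizes in the shape check are arbitrary values chosen for illustration, and the intermediate names `input_term` and `recurrent_term` are not part of the problem statement.

```python
import numpy as np

def rnn_cell_forward(x, h_prev, W_x, W_h, b):
    """One vanilla RNN step: h_next = tanh(x @ W_x + h_prev @ W_h + b)."""
    # (B, input_size) @ (input_size, hidden_size) -> (B, hidden_size)
    input_term = x @ W_x
    # (B, hidden_size) @ (hidden_size, hidden_size) -> (B, hidden_size)
    recurrent_term = h_prev @ W_h
    # b has shape (hidden_size,) and broadcasts over the batch axis
    return np.tanh(input_term + recurrent_term + b)

# Shape check with arbitrary illustrative sizes (assumed, not from the problem)
rng = np.random.default_rng(0)
B, input_size, hidden_size = 3, 5, 4
h_next = rnn_cell_forward(
    rng.standard_normal((B, input_size)),
    rng.standard_normal((B, hidden_size)),
    rng.standard_normal((input_size, hidden_size)),
    rng.standard_normal((hidden_size, hidden_size)),
    rng.standard_normal(hidden_size),
)
print(h_next.shape)  # (3, 4)
```

Keeping the two matmuls as separate named terms makes the shapes easy to check; collapsing them into a single expression is equivalent.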

Examples

Example 1 — the bias broadcasts across the batch

Input:  x = 0 (B, I), h_prev = 0 (B, H), W_x = 0, W_h = 0, b = [1, -1, 0.5, -0.5]
Output: every row = tanh(b) ≈ [0.762, -0.762, 0.462, -0.462]

Explanation: with zero input and zero previous state the pre-activation is just b, so each of the B batch rows is tanh(b) — the (H,) bias broadcasts across the batch dimension.
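Example 1 can be reproduced in a few lines. The sketch below assumes a one-line NumPy version of `rnn_cell_forward`; the values B = 3 and I = 2 are arbitrary, while H = 4 is fixed by the length of b.

```python
import numpy as np

def rnn_cell_forward(x, h_prev, W_x, W_h, b):
    # Assumed minimal implementation of the formula in the problem statement
    return np.tanh(x @ W_x + h_prev @ W_h + b)

B, I, H = 3, 2, 4  # B and I are arbitrary; H matches len(b)
b = np.array([1.0, -1.0, 0.5, -0.5])

h_next = rnn_cell_forward(
    np.zeros((B, I)), np.zeros((B, H)),   # zero input, zero previous state
    np.zeros((I, H)), np.zeros((H, H)),   # zero weights
    b,
)

# Every row is tanh(b), roughly [0.762, -0.762, 0.462, -0.462]
print(np.round(h_next, 3))
```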

Example 2 — tanh bounds the output

Input:  large-magnitude x, h_prev, W_x, W_h (e.g. scaled by 100 / 10)
Output: every entry of h_next lies in (-1, 1)

Explanation: however large the pre-activation, tanh squashes it into $(-1, 1)$. That bounded range is what stops the hidden state from exploding as it is fed back through W_h each step, but it also drives the local gradient $1 - \tanh^2(z)$ toward $0$ for large $|z|$. This saturation is one source of vanishing gradients over long sequences.
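The saturation effect is easy to see numerically. The sketch below evaluates tanh and its derivative at a few illustrative pre-activations (the specific z values are assumptions chosen for the demonstration). One practical caveat: in float64, `np.tanh` typically rounds to exactly 1.0 for very large |z|, so a numerical bounds check is safer written with `<=` than with a strict `<`.

```python
import numpy as np

z = np.array([0.0, 1.0, 3.0, 10.0, 30.0])  # illustrative pre-activations
h = np.tanh(z)
local_grad = 1.0 - h**2                     # d tanh(z) / dz

for zi, hi, gi in zip(z, h, local_grad):
    print(f"z={zi:5.1f}  tanh(z)={hi:.6f}  1 - tanh^2(z)={gi:.2e}")

# The derivative is 1 at z = 0 and shrinks toward 0 as |z| grows
```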

Constraints

  • Compute tanh(x @ W_x + h_prev @ W_h + b); the two matmuls produce (B, H) and add elementwise, with b broadcasting across the batch.
  • Output is always bounded in $(-1, 1)$ regardless of input magnitude.
  • Identity-recurrence sanity check: with W_x=0, W_h=I, b=0 and small h_prev, the output ≈ h_prev (since tanh(z) ≈ z for small z).
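The identity-recurrence check in the last constraint can be sketched as follows. The sizes and the particular small h_prev values are assumptions for illustration, and `rnn_cell_forward` is again a minimal assumed implementation of the formula.

```python
import numpy as np

def rnn_cell_forward(x, h_prev, W_x, W_h, b):
    # Assumed minimal implementation of the formula in the problem statement
    return np.tanh(x @ W_x + h_prev @ W_h + b)

B, I, H = 2, 3, 4                      # arbitrary illustrative sizes
x = np.ones((B, I))                    # ignored, because W_x = 0
h_prev = 0.01 * np.array([[1.0, -2.0, 3.0, -4.0],
                          [0.5, -0.5, 2.0, -1.0]])  # small hidden state

h_next = rnn_cell_forward(x, h_prev, np.zeros((I, H)), np.eye(H), np.zeros(H))

# tanh(z) = z - z^3/3 + ..., so the gap is on the order of |h_prev|^3 / 3
max_gap = np.max(np.abs(h_next - h_prev))
print(max_gap < 1e-4)  # True
```

The approximation degrades as h_prev grows: the cubic term of the Taylor series is no longer negligible once entries approach 1.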

Notes

  • Why tanh, not ReLU. The hidden state is multiplied by W_h and re-fed every step; with an unbounded ReLU, repeated multiplication can make it explode, so the bounded $(-1, 1)$ range of tanh (or the $(0, 1)$ range of sigmoid) is the classic choice.
  • Historical arc. Vanilla RNNs gave way to LSTMs (introduced in 1997, widely adopted by around 2014) and then to transformers (introduced in 2017); this cell is the baseline the attention mechanism ultimately replaced.
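The tanh-versus-ReLU note can be illustrated with a toy recurrence. The sketch below assumes a recurrent matrix with gain above 1 (W_h = 1.5·I), no input term, and 50 steps; these are illustrative choices, and a ReLU recurrence does not necessarily explode when W_h has smaller gain.

```python
import numpy as np

H, T = 4, 50                  # illustrative hidden size and number of steps
W_h = 1.5 * np.eye(H)         # assumed recurrent weights with gain > 1
h_relu = np.full((1, H), 0.1)
h_tanh = np.full((1, H), 0.1)

for _ in range(T):
    h_relu = np.maximum(0.0, h_relu @ W_h)  # unbounded activation
    h_tanh = np.tanh(h_tanh @ W_h)          # bounded activation

print(np.max(np.abs(h_relu)))  # grows like 0.1 * 1.5**50
print(np.max(np.abs(h_tanh)))  # stays below 1
```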

This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.

  • Output shape is (B, hidden_size)
  • Diagnostic: matches the explicit formula
  • Output values are bounded in (-1, 1) — tanh saturation
  • Bias broadcasts correctly across the batch dimension
  • Identity-recurrence behaviour: with W_x=0, W_h=I, b=0 and small h_prev, h_next stays close to h_prev