Matmul backward
Background
This is the backward pass for plain matrix multiplication: the gradient underneath every linear layer, attention head, and im2col convolution. The implementation is two lines, and, as with the linear backward, the shapes alone determine the formulas. For general (non-square) shapes there is only one dimension-consistent way to combine the available matrices.
Problem statement
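Implement matmul_backward(grad_out, A, B). The forward pass computed $C = AB$ with $A \in \mathbb{R}^{M \times K}$, $B \in \mathbb{R}^{K \times N}$, and $C \in \mathbb{R}^{M \times N}$. Given the upstream gradient $G = \partial L / \partial C$ of shape $(M, N)$, compute
$$\frac{\partial L}{\partial A} = G B^\top \in \mathbb{R}^{M \times K}, \qquad \frac{\partial L}{\partial B} = A^\top G \in \mathbb{R}^{K \times N}.$$
Return (dA, dB).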
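A minimal NumPy sketch of the two formulas follows. It assumes float arrays with the shapes listed under Input and is one possible solution, not the grader's reference implementation; the last lines replay Example 1.

```python
import numpy as np

def matmul_backward(grad_out, A, B):
    """Backward pass for C = A @ B.

    grad_out: dL/dC, shape (M, N)
    A:        left operand, shape (M, K)
    B:        right operand, shape (K, N)
    Returns (dA, dB) with shapes (M, K) and (K, N).
    """
    dA = grad_out @ B.T  # (M, N) @ (N, K) -> (M, K)
    dB = A.T @ grad_out  # (K, M) @ (M, N) -> (K, N)
    return dA, dB

# Example 1 (M = 1, K = 2, N = 2)
A = np.array([[1.0, 2.0]])
B = np.array([[3.0, 4.0], [5.0, 6.0]])
grad_out = np.array([[1.0, 1.0]])
dA, dB = matmul_backward(grad_out, A, B)
print(dA.tolist())  # [[7.0, 11.0]]
print(dB.tolist())  # [[1.0, 1.0], [2.0, 2.0]]
```

Each transpose is placed so the inner dimensions of the product agree and the result has the shape of the operand being differentiated.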
Input
- grad_out: np.ndarray of shape (M, N), the gradient of the loss w.r.t. C.
- A: np.ndarray of shape (M, K), the left operand from the forward pass.
- B: np.ndarray of shape (K, N), the right operand from the forward pass.
Output
Returns a tuple (dA, dB) with shapes (M, K) and (K, N).
Examples
Example 1: hand-checked (M = 1, K = 2, N = 2)
Input: A = [[1, 2]], B = [[3, 4], [5, 6]], grad_out = [[1, 1]]
Output: dA = [[7, 11]], dB = [[1, 1], [2, 2]]
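Explanation: dA = grad_out @ B.T = [[1·3 + 1·4, 1·5 + 1·6]] = [[7, 11]]; dB = A.T @ grad_out = [[1·1, 1·1], [2·1, 2·1]] = [[1, 1], [2, 2]].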
Example 2: zero upstream gradient gives zero gradients
Input: grad_out = zeros((2, 2)), any A (2,2), any B (2,2)
Output: dA = zeros((2, 2)), dB = zeros((2, 2))
Explanation: both gradients are linear in grad_out, so a zero upstream gradient yields zero dA and dB.
Constraints
- dA = grad_out @ B.T → (M, K); dB = A.T @ grad_out → (K, N).
- Shapes are the guardrail: with grad_out (M, N) and B (K, N), only grad_out @ B.T lands on (M, K) when M, K, and N differ. A misplaced transpose then raises a shape error or returns the wrong shape; with square operands it can still run and silently give wrong values.
- The diagnostic finite-difference cross-check (perturb each entry of A, then B) must match within atol ≈ 1e-3.
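One way to run the finite-difference cross-check is sketched below. The helper numerical_grad, the scalar loss L = sum(C * G) (chosen because its gradient with respect to C is exactly G), the step eps = 1e-6, and the random non-square shapes are all assumptions for illustration; the hidden test's exact procedure is not specified here.

```python
import numpy as np

def matmul_backward(grad_out, A, B):
    # Analytic gradients of C = A @ B
    return grad_out @ B.T, A.T @ grad_out

def numerical_grad(f, X, eps=1e-6):
    """Hypothetical helper: central-difference gradient of scalar f() w.r.t. X.

    Perturbs X in place one entry at a time and restores it afterwards.
    """
    grad = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        orig = X[idx]
        X[idx] = orig + eps
        f_plus = f()
        X[idx] = orig - eps
        f_minus = f()
        X[idx] = orig  # restore the entry
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
M, K, N = 3, 4, 5                 # non-square, so a wrong transpose cannot pass silently
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))
G = rng.standard_normal((M, N))   # stand-in upstream gradient dL/dC

def loss():
    # Scalar loss whose gradient w.r.t. C = A @ B is exactly G
    return np.sum((A @ B) * G)

dA, dB = matmul_backward(G, A, B)
num_dA = numerical_grad(loss, A)  # perturb each entry of A
num_dB = numerical_grad(loss, B)  # then each entry of B
print(np.allclose(dA, num_dA, atol=1e-3), np.allclose(dB, num_dB, atol=1e-3))
```

Because the loss is linear in each operand, central differences in float64 land far inside the 1e-3 tolerance; the tolerance mainly guards against float32 inputs or larger steps.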
Notes
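- Why the transposes. Writing $G = \partial L / \partial C$ and starting from $C_{mn} = \sum_k A_{mk} B_{kn}$, the gradient at $A_{mk}$ is $\partial L / \partial A_{mk} = \sum_n G_{mn} B_{kn}$, exactly the $(m, k)$ entry of $G B^\top$. The same reasoning gives $\partial L / \partial B_{kn} = \sum_m A_{mk} G_{mn}$, i.e. $\partial L / \partial B = A^\top G$.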
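The index-level sums above can be checked directly with np.einsum, which spells out each sum over the shared index. The shapes and random data below are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((2, 3))  # (M, K)
B = rng.standard_normal((3, 4))  # (K, N)
G = rng.standard_normal((2, 4))  # dL/dC, (M, N)

# dA[m, k] = sum_n G[m, n] * B[k, n]  -> the (m, k) entry of G @ B.T
dA_index = np.einsum("mn,kn->mk", G, B)
# dB[k, n] = sum_m A[m, k] * G[m, n]  -> the (k, n) entry of A.T @ G
dB_index = np.einsum("mk,mn->kn", A, G)

print(np.allclose(dA_index, G @ B.T), np.allclose(dB_index, A.T @ G))
```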
- Same pattern as Linear. Linear backward's dx and dW are this matmul gradient: a linear layer is just A @ B (input × weight) plus a bias term.
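As a sketch of that reuse, a linear layer's backward can apply the same two formulas and add a bias gradient. The function name linear_backward is hypothetical, and it assumes the convention y = x @ W + b with W of shape (in_features, out_features); frameworks that store W as (out_features, in_features), such as PyTorch's nn.Linear, transpose accordingly.

```python
import numpy as np

def linear_backward(grad_out, x, W):
    """Hypothetical helper: backward pass of y = x @ W + b.

    grad_out: dL/dy, shape (batch, out_features)
    x:        layer input, shape (batch, in_features)
    W:        weight, shape (in_features, out_features)
    Returns (dx, dW, db).
    """
    dx = grad_out @ W.T        # matmul rule with A = x, B = W
    dW = x.T @ grad_out        # matmul rule with A = x, B = W
    db = grad_out.sum(axis=0)  # b is broadcast over the batch, so its gradient sums over it
    return dx, dW, db

# Example 1's operands, read as one input row and a 2x2 weight
x = np.array([[1.0, 2.0]])
W = np.array([[3.0, 4.0], [5.0, 6.0]])
grad_out = np.array([[1.0, 1.0]])
dx, dW, db = linear_backward(grad_out, x, W)
```

With these inputs, dx and dW reproduce Example 1's dA and dB, and db is [1, 1].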
This problem ships 5 hidden tests, which run in the browser via Pyodide:
- Returns two arrays with correct shapes
- dA matches grad_out @ B.T
- dB matches A.T @ grad_out
- Diagnostic: matches finite-difference numerical gradient
- Zero grad_out -> zero gradients