Matmul backward (Medium)

Background

This is the backward pass for plain matrix multiplication — the gradient that sits underneath every linear layer, attention head, and im2col convolution. It is two lines, and like the linear backward, the shapes alone force the formulas: there is only one dimension-consistent way to combine the available matrices.

Problem statement

Implement matmul_backward(grad_out, A, B). The forward pass was $C = AB$ with $A: (M, K)$, $B: (K, N)$, $C: (M, N)$. Given $G = \partial L/\partial C$ of shape $(M, N)$:

$$dA = \frac{\partial L}{\partial A} = G\,B^\top \;\; (M, K), \qquad dB = \frac{\partial L}{\partial B} = A^\top G \;\; (K, N)$$

Return (dA, dB).
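A minimal sketch of the two formulas above, assuming the inputs are 2-D NumPy arrays with the stated shapes. It does no input validation and is one possible way to write the function, not necessarily the reference solution.

```python
import numpy as np

def matmul_backward(grad_out, A, B):
    """Gradients of C = A @ B with respect to A and B (sketch)."""
    dA = grad_out @ B.T   # (M, N) @ (N, K) -> (M, K)
    dB = A.T @ grad_out   # (K, M) @ (M, N) -> (K, N)
    return dA, dB

# Shape check on small all-ones inputs (M=2, K=3, N=4).
A = np.ones((2, 3))
B = np.ones((3, 4))
G = np.ones((2, 4))
dA, dB = matmul_backward(G, A, B)
print(dA.shape, dB.shape)
```

Each gradient has the same shape as the operand it belongs to, which is a quick sanity check before comparing values.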

Input

  • grad_out (np.ndarray of shape (M, N)): the gradient of the loss w.r.t. C.
  • A (np.ndarray of shape (M, K)): the left operand from the forward pass.
  • B (np.ndarray of shape (K, N)): the right operand from the forward pass.

Output

Returns a tuple (dA, dB) with shapes (M, K) and (K, N).

Examples

Example 1 — hand-checked ($M=1, K=2, N=2$)

Input:  A = [[1, 2]], B = [[3, 4],
                           [5, 6]], grad_out = [[1, 1]]
Output: dA = [[7, 11]]
        dB = [[1, 1],
              [2, 2]]

Explanation: $dA = G\,B^\top = [1,1]\cdot\begin{bmatrix}3&5\\4&6\end{bmatrix} = [7, 11]$; $dB = A^\top G = \begin{bmatrix}1\\2\end{bmatrix}\cdot[1,1] = \begin{bmatrix}1&1\\2&2\end{bmatrix}$.
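The hand computation in Example 1 can be reproduced directly in NumPy. This is only a check of the arithmetic, using the same values as the example; the variable name G for grad_out follows the notation of the problem statement.

```python
import numpy as np

A = np.array([[1.0, 2.0]])                # shape (1, 2)
B = np.array([[3.0, 4.0],
              [5.0, 6.0]])                # shape (2, 2)
G = np.array([[1.0, 1.0]])                # grad_out, shape (1, 2)

dA = G @ B.T   # [1, 1] times [[3, 5], [4, 6]] -> [[7, 11]]
dB = A.T @ G   # column [1, 2] times row [1, 1] -> [[1, 1], [2, 2]]

print(dA)
print(dB)
```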

Example 2 — zero upstream gradient gives zero gradients

Input:  grad_out = zeros((2, 2)), any A (2,2), any B (2,2)
Output: dA = zeros((2, 2)), dB = zeros((2, 2))

Explanation: both gradients are linear in grad_out, so a zero upstream gradient yields zero dA and dB.
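Linearity in grad_out can be illustrated with a small experiment: a zero upstream gradient gives zeros, and scaling grad_out by a constant scales both gradients by the same constant. The helper name grads, the random seed, and the scale factor 3.0 are arbitrary choices made for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 2))
B = rng.standard_normal((2, 2))
G = rng.standard_normal((2, 2))

def grads(grad_out):
    # Hypothetical helper: the two matmul gradients for fixed A and B.
    return grad_out @ B.T, A.T @ grad_out

dA_zero, dB_zero = grads(np.zeros((2, 2)))
dA_one, dB_one = grads(G)
dA_three, dB_three = grads(3.0 * G)

print(np.any(dA_zero), np.any(dB_zero))       # no nonzero entries
print(np.allclose(dA_three, 3.0 * dA_one))    # scaling carries through
print(np.allclose(dB_three, 3.0 * dB_one))
```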

Constraints

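  • dA = grad_out @ B.T, shape (M, K); dB = A.T @ grad_out, shape (K, N).
  • Shapes are the guardrail: with grad_out (M, N) and B (K, N), only grad_out @ B.T lands on (M, K). A wrong transpose usually raises a shape error or returns the wrong shape; when dimensions coincide (for example, square matrices) it runs silently with wrong values, which is why the finite-difference check matters.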
  • The diagnostic finite-difference cross-check (perturb each entry of A, then B) must match within atol≈1e-3.
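A finite-difference cross-check along these lines might look like the sketch below. It assumes a scalar loss L = sum(G * (A @ B)), chosen so that the gradient of L with respect to C is exactly G; the hidden test may use a different loss. The helper numerical_grad, the step eps=1e-5, and the shapes are all illustrative choices.

```python
import numpy as np

def numerical_grad(f, X, eps=1e-5):
    """Central-difference gradient of scalar f() w.r.t. each entry of X.

    X is perturbed in place and restored after each evaluation.
    """
    grad = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        old = X[idx]
        X[idx] = old + eps
        plus = f()
        X[idx] = old - eps
        minus = f()
        X[idx] = old                      # restore the original entry
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
M, K, N = 3, 4, 2
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))
G = rng.standard_normal((M, N))

# Assumed loss: L = sum(G * C), so dL/dC = G.
def loss():
    return float(np.sum(G * (A @ B)))

dA = G @ B.T
dB = A.T @ G
num_dA = numerical_grad(loss, A)   # perturb each entry of A
num_dB = numerical_grad(loss, B)   # then each entry of B

print(np.allclose(dA, num_dA, atol=1e-3))
print(np.allclose(dB, num_dB, atol=1e-3))
```

Because this loss is linear in each operand separately, central differences agree with the analytic gradient up to floating-point rounding, so the 1e-3 tolerance is comfortably loose here.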

Notes

  • Why the transposes. From $C_{ij} = \sum_k A_{ik}B_{kj}$, the gradient at $A_{ik}$ is $\sum_j G_{ij}B_{kj}$, which is exactly the $(i,k)$ entry of $G B^\top$. The same reasoning gives $dB = A^\top G$.
  • Same pattern as Linear. Linear backward's dx and dW are this matmul gradient — a linear layer is just A @ B (input × weight) plus a bias term.
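Both notes can be checked numerically. The sketch below writes the index sums with np.einsum and compares them with the matrix forms, then relabels the operands as a linear layer y = x @ W + b. The names x, W, b and the bias gradient db (the upstream gradient summed over the batch axis) are assumptions added for illustration and are not part of this problem's interface.

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 3, 4, 2
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))
G = rng.standard_normal((M, N))

# Index form: dA[i, k] = sum_j G[i, j] * B[k, j]
dA_sum = np.einsum("ij,kj->ik", G, B)
# Index form: dB[k, j] = sum_i A[i, k] * G[i, j]
dB_sum = np.einsum("ik,ij->kj", A, G)

print(np.allclose(dA_sum, G @ B.T))
print(np.allclose(dB_sum, A.T @ G))

# Linear layer view (assumed names): y = x @ W + b with x = A, W = B.
x, W = A, B
dx = G @ W.T          # same formula as dA
dW = x.T @ G          # same formula as dB
db = G.sum(axis=0)    # bias is broadcast over rows, so its gradient sums over them

print(dx.shape, dW.shape, db.shape)
```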

This problem ships 5 hidden tests. They run in your browser via Pyodide — no backend, no submission queue. Press ▶ Run tests to execute.

  • Returns two arrays with correct shapes
  • dA matches grad_out @ B.T
  • dB matches A.T @ grad_out
  • Diagnostic: matches finite-difference numerical gradient
  • Zero grad_out -> zero gradients