LoRA forward pass (Easy)

Background

LoRA (Low-Rank Adaptation) fine-tunes a large model cheaply. It freezes the original weight matrix $W_0$ and learns only a small low-rank update $\Delta W = BA$. Here $A \in \mathbb{R}^{r \times \text{in}}$ and $B \in \mathbb{R}^{\text{out} \times r}$, with a rank $r$ much smaller than in and out. A layer therefore trains only $r(\text{in} + \text{out})$ parameters instead of $\text{out} \times \text{in}$, often orders of magnitude fewer. Even so, the adapted layer can still shift its behavior substantially.
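A quick count makes "orders of magnitude" concrete. The layer size and rank below (a 4096 × 4096 layer with r = 8) are illustrative assumptions, not values from the problem. The helper names `full_params` and `lora_params` are hypothetical.

```python
# Compare trainable parameter counts for one linear layer.
# Full fine-tuning updates every entry of W0 (out * in values).
# LoRA trains only A (r * in values) and B (out * r values).

def full_params(d_in: int, d_out: int) -> int:
    """Trainable parameters when W0 itself is updated."""
    return d_out * d_in

def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters for the LoRA factors A and B."""
    return r * d_in + d_out * r

# Illustrative sizes (assumed): a 4096 x 4096 layer adapted with rank 8.
d_in = d_out = 4096
r = 8
full = full_params(d_in, d_out)
lora = lora_params(d_in, d_out, r)
print(full, lora, full // lora)  # 16777216 65536 256
```

For this square layer, LoRA trains 256 times fewer parameters. The ratio grows as the layer gets wider relative to r.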

Problem statement

Implement lora_forward(x, W0, A, B, alpha) for the adapted linear layer:

$$W_{\text{eff}} = W_0 + \frac{\alpha}{r}\, B A, \qquad y = x\, W_{\text{eff}}^\top$$

where $r$ is the LoRA rank (A.shape[0]) and $\alpha / r$ is the scaling factor.
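The formula translates almost line for line into NumPy. The sketch below assumes the inputs may arrive as nested lists, which is why they pass through `np.asarray`. It also assumes a float result is acceptable. This is one possible implementation, not the reference solution.

```python
import numpy as np

def lora_forward(x, W0, A, B, alpha):
    """LoRA-adapted linear layer: y = x @ (W0 + (alpha / r) * B @ A).T"""
    x = np.asarray(x, dtype=float)
    W0 = np.asarray(W0, dtype=float)
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)

    r = A.shape[0]                  # LoRA rank
    scaling = alpha / r             # scaling factor alpha / r
    W_eff = W0 + scaling * (B @ A)  # (out, in), same shape as W0
    # x @ W_eff.T maps (in,) -> (out,) and (batch, in) -> (batch, out)
    return x @ W_eff.T

y = lora_forward([1, 0], [[1, 0], [0, 1]], [[1, 0]], [[1], [0]], 2)
print(y.tolist())  # [3.0, 0.0]
```

A single `x @ W_eff.T` covers both the 1-D and the batched case. NumPy's `@` treats a 1-D left operand as a row vector and drops that axis from the result.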

Input

  • x — input, shape (in,) or (batch, in).
  • W0 — frozen base weights, shape (out, in).
  • A — LoRA down-projection, shape (r, in).
  • B — LoRA up-projection, shape (out, r).
  • alpha — float, the LoRA scaling numerator.

Output

An np.ndarray of shape (out,) for a 1-D input or (batch, out) for a batched input.

Examples

Example 1

Input:  x = [1, 0], W0 = [[1,0],[0,1]], A = [[1,0]], B = [[1],[0]], alpha = 2
Output: [3, 0]

Explanation: $r = 1$, so the scaling factor is $2/1 = 2$. $BA = [[1,0],[0,0]]$, so $W_{\text{eff}} = [[3,0],[0,1]]$. Then $y = [1,0]\, W_{\text{eff}}^\top = [3,0]$.
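The explanation's arithmetic can be replayed in NumPy, with one intermediate per line, to confirm each value.

```python
import numpy as np

# Example 1, reproduced step by step.
x = np.array([1.0, 0.0])
W0 = np.array([[1.0, 0.0], [0.0, 1.0]])
A = np.array([[1.0, 0.0]])     # shape (r, in) = (1, 2)
B = np.array([[1.0], [0.0]])   # shape (out, r) = (2, 1)
alpha = 2.0

r = A.shape[0]                 # 1
scaling = alpha / r            # 2.0
BA = B @ A                     # [[1, 0], [0, 0]]
W_eff = W0 + scaling * BA      # [[3, 0], [0, 1]]
y = x @ W_eff.T                # [3, 0]
print(y.tolist())              # [3.0, 0.0]
```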

Constraints

  • The scaling factor is $\alpha / r$, where $r$ = A.shape[0].
  • $\Delta W = BA$ has the same shape as $W_0$ (out × in).
  • Support both a single vector and a batch of inputs.
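The shape constraints can also be checked explicitly. The product can be evaluated without ever forming the (out, in) matrix $BA$, by using $y = x W_0^\top + \frac{\alpha}{r}(x A^\top) B^\top$, which is algebraically identical to the formula above. The sketch below takes that unmerged route. The name `lora_forward_unmerged` is hypothetical, and raising `ValueError` on bad shapes is an assumed design choice.

```python
import numpy as np

def lora_forward_unmerged(x, W0, A, B, alpha):
    """Hypothetical variant of lora_forward that never forms the (out, in)
    product B @ A. It computes x @ W0.T + (alpha / r) * (x @ A.T) @ B.T."""
    x = np.asarray(x, dtype=float)
    W0, A, B = (np.asarray(m, dtype=float) for m in (W0, A, B))
    out_dim, in_dim = W0.shape
    r = A.shape[0]
    # Enforce the constraint that B @ A has the same (out, in) shape as W0.
    if A.shape != (r, in_dim) or B.shape != (out_dim, r):
        raise ValueError("expected A of shape (r, in) and B of shape (out, r)")
    if x.shape[-1] != in_dim:
        raise ValueError("last axis of x must have length in")
    base = x @ W0.T                # (out,) or (batch, out)
    low_rank = (x @ A.T) @ B.T     # project down to r, then back up to out
    return base + (alpha / r) * low_rank

# Works for a single vector and for a batch.
print(lora_forward_unmerged([1, 0], [[1, 0], [0, 1]], [[1, 0]], [[1], [0]], 2).tolist())  # [3.0, 0.0]
print(lora_forward_unmerged([[1, 0], [0, 1]], [[1, 0], [0, 1]], [[1, 0]], [[1], [0]], 2).shape)  # (2, 2)
```

Per input row, the low-rank path adds roughly $r(\text{in} + \text{out})$ multiply-adds on top of the base matmul, and it leaves $W_0$ untouched. That arrangement suits training, when $W_0$ stays frozen and separate.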

Notes

  • At inference, $W_{\text{eff}}$ can be folded into a single matrix ahead of time, so a merged LoRA layer adds no latency over the base layer.
  • Only $A$ and $B$ are trained. Because $W_0$ stays frozen, it needs no gradients or optimizer state, which is what makes adapting a 70B-parameter model feasible on modest hardware.
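The merge step can be sketched with two hypothetical helpers, `merge_lora` and `unmerge_lora`. Folding $\frac{\alpha}{r}BA$ into $W_0$ once gives the same outputs as the split base-plus-low-rank computation. Subtracting the same term recovers $W_0$, up to floating-point rounding. The random shapes and seed are arbitrary choices for illustration.

```python
import numpy as np

def merge_lora(W0, A, B, alpha):
    """Fold the LoRA update into one matrix: W0 + (alpha / r) * B @ A."""
    r = A.shape[0]
    return W0 + (alpha / r) * (B @ A)

def unmerge_lora(W_merged, A, B, alpha):
    """Recover W0 by subtracting the same update (exact up to float rounding)."""
    r = A.shape[0]
    return W_merged - (alpha / r) * (B @ A)

rng = np.random.default_rng(0)
out_dim, in_dim, r, alpha = 6, 5, 2, 4.0
W0 = rng.standard_normal((out_dim, in_dim))
A = rng.standard_normal((r, in_dim))
B = rng.standard_normal((out_dim, r))
x = rng.standard_normal((3, in_dim))

W_merged = merge_lora(W0, A, B, alpha)
# One matmul with the merged weights equals base path + low-rank path.
y_merged = x @ W_merged.T
y_split = x @ W0.T + (alpha / r) * (x @ A.T) @ B.T
print(np.allclose(y_merged, y_split))                        # True
print(np.allclose(unmerge_lora(W_merged, A, B, alpha), W0))  # True
```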

This problem ships 5 hidden tests, which run in the browser via Pyodide (no backend or submission queue):

  • Reference example
  • Zero B reduces to the base layer
  • Scaling is alpha / r
  • Batched input
  • Matches explicit W0 + scaling*BA
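The hidden tests themselves are not shown. As a rough local stand-in, the checks below are hypothetical assertions modeled only on the five test names. They run against a compact NumPy sketch of `lora_forward`, which is an assumption and not the reference solution. Random shapes and the seed are arbitrary.

```python
import numpy as np

def lora_forward(x, W0, A, B, alpha):
    """Compact sketch: y = x @ (W0 + (alpha / r) * B @ A).T with r = A.shape[0]."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    W_eff = np.asarray(W0, dtype=float) + (alpha / A.shape[0]) * (B @ A)
    return np.asarray(x, dtype=float) @ W_eff.T

rng = np.random.default_rng(42)
out_dim, in_dim, r, alpha = 4, 3, 2, 3.0
W0 = rng.standard_normal((out_dim, in_dim))
A = rng.standard_normal((r, in_dim))
B = rng.standard_normal((out_dim, r))
x = rng.standard_normal(in_dim)        # single input, shape (in,)
X = rng.standard_normal((5, in_dim))   # batch, shape (5, in)

# Reference example
assert np.allclose(lora_forward([1, 0], [[1, 0], [0, 1]], [[1, 0]], [[1], [0]], 2), [3, 0])

# Zero B reduces to the base layer, whatever alpha is
assert np.allclose(lora_forward(x, W0, A, np.zeros_like(B), alpha), x @ W0.T)

# Scaling is alpha / r (here 1.5), not alpha alone
delta = lora_forward(x, W0, A, B, alpha) - x @ W0.T
assert np.allclose(delta, (alpha / r) * (x @ (B @ A).T))

# Batched input: (batch, in) -> (batch, out), row-wise equal to the 1-D call
Y = lora_forward(X, W0, A, B, alpha)
assert Y.shape == (5, out_dim)
assert np.allclose(Y[0], lora_forward(X[0], W0, A, B, alpha))

# Matches explicit W0 + scaling * BA
W_eff = W0 + (alpha / r) * (B @ A)
assert np.allclose(Y, X @ W_eff.T)
print("all checks passed")
```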