# Crushing MATLAB Loop Runtimes with BSXFUN

Gallagher Pryor | Benchmarks

For/while loops are among the biggest contributors to long runtimes in MATLAB code. In this blog post, I'm going to talk about a little-known way of crushing MATLAB loop runtimes for many common use cases by using one of the most underrated functions in MATLAB's repertoire: bsxfun. With this function, seemingly iterative code can be broken into clean, vectorized snippets that beat the socks off even MATLAB's JIT engine. Better still, Jacket fully supports bsxfun, so if you thought a vectorized loop was fast, you haven't seen anything yet. Finally, a loop expressed with bsxfun is simply good programming practice. As we'll see, the technique I'm going to describe is more restrictive than a general for-loop and therefore describes your computation more concisely, which opens the door to further optimization of your algorithm and leaves fewer chances for bugs in the long run.

The running example for this blog post computes a distance matrix from two matrices `a` and `b`. We expect both `a` and `b` to have size `[3, N]`, representing `N` length-3 vectors (say, positions in 3D). We wish to compute a matrix `C` whose `(i,j)`th element is the Euclidean distance between the `i`th vector in `b` and the `j`th vector in `a`:

```matlab
% example data
a = rand(3, 2000); b = rand(3, 2000);
sz_a = size(a); sz_b = size(b);

% (non-JIT for loop version)
tic
C = zeros(sz_b(2), sz_a(2));
for i = 1:sz_b(2)
  for j = 1:sz_a(2)
    C(i,j) = norm(a(:,j) - b(:,i));
  end
end
t1 = toc
```

On my system, this times in at 7.46 seconds.

Notice that `norm` is used inside the loop. The MATLAB JIT is not smart enough to handle this more exotic function, so let's see what happens when we switch to straight arithmetic:

```matlab
tic
C = zeros(sz_b(2), sz_a(2));
for i = 1:sz_b(2)
  for j = 1:sz_a(2)
    C(i,j) = sqrt((a(1,j) - b(1,i)).^2 + (a(2,j) - b(2,i)).^2 + ...
                  (a(3,j) - b(3,i)).^2);
  end
end
t2 = toc
```

This times in at 0.3657 seconds, a 20x improvement!

This time the MATLAB JIT was able to kick in, because the loop body contains nothing but arithmetic. Now let's flatten this loop using bsxfun, see how much faster it is, and understand how it works.

```matlab
tic
c1 = bsxfun(@minus, a(1,:), b(1,:)');
c2 = bsxfun(@minus, a(2,:), b(2,:)');
c3 = bsxfun(@minus, a(3,:), b(3,:)');
C = sqrt(c1.^2 + c2.^2 + c3.^2);
t3 = toc
```

This times in at 0.166 seconds, a 45x improvement over the first loop!

Let's go through this line by line (counting `tic` as line 1):

• LINE 2: `c1` becomes an N by N matrix whose element (i,j) is the difference `a(1,j) - b(1,i)`
• LINE 3: `c2` becomes an N by N matrix whose element (i,j) is the difference `a(2,j) - b(2,i)`
• LINE 4: `c3` becomes an N by N matrix whose element (i,j) is the difference `a(3,j) - b(3,i)`
• LINE 5: component-wise arithmetic combines `c1`, `c2`, and `c3` into the Euclidean distance.

The key to understanding lines 2, 3, and 4 is understanding bsxfun. According to the MATLAB documentation, bsxfun applies an element-by-element binary operation to two arrays with singleton expansion enabled. This means that bsxfun carries out component-by-component arithmetic (+, -, .*) just like ordinary MATLAB arithmetic such as `A + B` or `A .* B`. What bsxfun adds is that when the dimensions of `A` and `B` do not match, it expands every dimension of size 1 (a singleton dimension) to match the corresponding dimension of the other array; any dimension that is not a singleton must still agree. bsxfun then performs the arithmetic on these expanded, now compatible, arrays.

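For example, bsxfun conceptually carries out the following operations to accomplish the work on line 2:

```matlab
% copy the row a(1,:) down sz_b(2) rows and the column b(1,:)' across
% sz_a(2) columns, so both operands are sz_b(2)-by-sz_a(2)
a_ = repmat(a(1,:), sz_b(2), 1);
b_ = repmat(b(1,:)', 1, sz_a(2));
c1 = a_ - b_;
```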

But why waste memory, time, and thought when bsxfun can handle this for you?
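To make the equivalence concrete, here is a small sketch on toy data (the arrays `r` and `c` are made up for illustration). It checks that explicit replication with repmat and bsxfun give the same result. As an aside, it also uses the plain `-` operator, which performs the same expansion on releases with implicit expansion (MATLAB R2016b and later, and GNU Octave); on older MATLAB releases that line raises a dimension-mismatch error instead.

```matlab
% toy inputs: a 1x3 row and a 2x1 column
r = [1 2 3];
c = [10; 20];

% explicit replication to a common 2x3 shape
D_repmat = repmat(r, numel(c), 1) - repmat(c, 1, numel(r));

% singleton expansion without materializing the copies
D_bsx = bsxfun(@minus, r, c);

% implicit expansion (MATLAB R2016b+ / GNU Octave only)
D_implicit = r - c;

% each element (i,j) is r(j) - c(i): [-9 -8 -7; -19 -18 -17]
disp(D_bsx)
```

All three forms produce the same 2x3 matrix; only the repmat version allocates the two replicated operands explicitly.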

Taking this a step further, let's plug the same code into Jacket on a Tesla C1060 (please excuse the `geval` and `gsync` calls; they are necessary for accurate timing):

```matlab
a = gsingle(a); b = gsingle(b);
geval(a); geval(b);
tic
c1 = bsxfun(@minus, a(1,:), b(1,:)');
c2 = bsxfun(@minus, a(2,:), b(2,:)');
c3 = bsxfun(@minus, a(3,:), b(3,:)');
C = sqrt(c1.^2 + c2.^2 + c3.^2);
geval(C); gsync;
t4 = toc
```

This times in at 0.0288 seconds, a 259x improvement over the first loop!
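The same pattern extends beyond 3D. The minimal sketch below is not benchmarked here and has not been checked under Jacket. It assumes `a` and `b` are `d`-by-`N` double arrays on the CPU and replaces the three hand-written bsxfun lines with a short loop over the `d` rows. It then compares the result against the `norm`-based loop on fresh, small random data. The names `D2`, `C_ref`, and `max_err` are illustrative, not part of the original code.

```matlab
% small random test data: 5-dimensional points, with different counts so a
% swapped row/column index would show up as a size mismatch
a = rand(5, 30); b = rand(5, 20);

% accumulate squared coordinate differences one row (dimension) at a time;
% each bsxfun call yields a size(b,2)-by-size(a,2) matrix
D2 = zeros(size(b, 2), size(a, 2));
for k = 1:size(a, 1)
  D2 = D2 + bsxfun(@minus, a(k,:), b(k,:)').^2;
end
C = sqrt(D2);

% reference result from the norm-based double loop
C_ref = zeros(size(b, 2), size(a, 2));
for i = 1:size(b, 2)
  for j = 1:size(a, 2)
    C_ref(i,j) = norm(a(:,j) - b(:,i));
  end
end

% should be on the order of floating-point round-off
max_err = max(abs(C(:) - C_ref(:)))
```

The remaining loop runs only `d` times, so for a fixed, small dimension count the work is still dominated by the vectorized bsxfun calls.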