
Numpy ":" Operator Broadcasting Issues

In the following code I have written two methods that, in my mind, should do the same thing. Unfortunately they do not, and I am unable to find out why.

Solution 1:

As indicated, you generally cannot operate on the same element multiple times in a single operation, because of how buffering works in NumPy. For that purpose there is the at method, which is available on nearly every standard NumPy ufunc (add, subtract, etc.). For your case, you can do:

import numpy as np

dW = np.zeros((20, 10))
y = [1 for _ in range(100)]
X = np.ones((100, 20))
# at modifies dW in place; it does not return a new array
np.subtract.at(dW, (slice(None), y), X.T)
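To make the buffering point concrete, here is a minimal sketch on a 1-D array. The names a, b and idx are made up for illustration and are not part of the question's code. It compares a fancy-indexed in-place subtraction with np.subtract.at when the same index appears three times.

```python
import numpy as np

idx = [1, 1, 1]  # the same index repeated three times

# Buffered: a[idx] -= 1.0 reads a[idx] once, subtracts, and writes back,
# so the repeated index is effectively updated only once.
a = np.zeros(3)
a[idx] -= 1.0
buffered = a.copy()

# Unbuffered: np.subtract.at applies the subtraction once per occurrence.
b = np.zeros(3)
np.subtract.at(b, idx, 1.0)
unbuffered = b.copy()

print(buffered)    # element 1 is -1.0
print(unbuffered)  # element 1 is -3.0
```

The same reasoning carries over to the 2-D case above, where (slice(None), y) plays the role of idx.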

Solution 2:

This is a column-wise version of this question.

The answer there can be adapted to work column-wise as follows:

Approach 1: np.<ufunc>.at

>>> np.subtract.at(dW, (slice(None), y), X.T)

Approach 2: np.bincount

>>> m, n = dW.shape
>>> dW -= np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n)
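The one-liner is dense, so here is a small sketch that unpacks it and checks it against an explicit loop. The shapes and values (m = 3, n = 4, a mixed y) are made up for illustration and are smaller than those in the question. The idea is to turn each (row, column) target into a flat index row * n + column, then let np.bincount sum the weights that land on the same flat index.

```python
import numpy as np

m, n = 3, 4                                   # illustrative shape of dW
y = [1, 3, 1]                                 # target column per sample (1 repeats)
X = np.arange(9, dtype=float).reshape(3, m)   # one row per sample, m features

# Reference result: explicit loop, one sample at a time.
dW_loop = np.zeros((m, n))
for i in range(len(y)):
    dW_loop[:, y[i]] -= X[i]

# Flat index of dW[r, y[j]] is r * n + y[j]; result has shape (m, len(y)).
flat_idx = np.add.outer(np.arange(m) * n, y).ravel()

# X.T has shape (m, len(y)), so its ravel lines up with flat_idx.
# bincount sums all weights that share a flat index.
sums = np.bincount(flat_idx, weights=X.T.ravel(), minlength=m * n)

dW_bincount = np.zeros((m, n))
dW_bincount -= sums.reshape(m, n)

matches = np.allclose(dW_loop, dW_bincount)
print(matches)
```

Passing minlength=m * n guarantees the result has one bin per element of dW even when the last columns receive no updates, which is what makes the final reshape safe.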

Please note that the bincount-based solution, even though it involves more steps, is faster by a factor of ~6 in the timings below.

>>> from timeit import repeat
>>> kwds = dict(globals=globals(), number=5000)
>>>
>>> repeat('np.subtract.at(dW, (slice(None), y), X.T); np.add.at(dW, (slice(None), y), X.T)', **kwds)
[1.590626839082688, 1.5769231889862567, 1.5802007300080732]
>>> repeat('_ = dW; _ -= np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n); _ += np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n)', **kwds)
[0.2582490430213511, 0.25572817400097847, 0.25478115503210574]

Solution 3:

Option 1:

for i in range(len(y)):
  dW[:, y[i]] -=  X[i]

This works because the loop applies the updates one at a time, so each iteration subtracts from the value left by the previous one.

Option 2:

dW[:, [1,1,1,1,....1,1,1]] -=  [[1,1,1,1...1],

This does not work because all updates to column 1 are applied at the same time rather than one after another. Every update starts from the initial value 0, so the subtraction leaves -1 in that column instead of the accumulated total.
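The difference between the two options can be checked directly. The sketch below reuses the array shapes from Solution 1 and assumes that Option 2 corresponds to dW[:, y] -= X.T, since the question's original code is not shown here.

```python
import numpy as np

y = [1 for _ in range(100)]
X = np.ones((100, 20))

# Option 1: serial updates, each iteration builds on the previous result.
dW_loop = np.zeros((20, 10))
for i in range(len(y)):
    dW_loop[:, y[i]] -= X[i]

# Option 2 (assumed form): one fancy-indexed update with a repeated column index.
dW_fancy = np.zeros((20, 10))
dW_fancy[:, y] -= X.T

print(dW_loop[0, 1])   # -100.0
print(dW_fancy[0, 1])  # -1.0
```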
