Np.where Checking Also For Subelements In Multidimensional Arrays
Solution 1:
[i for i, e in enumerate(x) if (e == z).all(1).any()]
Test case:
import numpy as np
x = np.array([[0,1,2,3], [4,0,6,9], [4,0,6,19]])
z = np.array([[4,0,6,9], [0,1,2,3]])
[i for i, e in enumerate(x) if (e == z).all(1).any()]
Output:
[0, 1]
Solution 2:
np.where simply returns the indices where your condition holds. Here the condition is element-wise equality, so it reports matching elements, not matching rows.
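A quick sketch of that element-wise behaviour, reusing the x and z from Solution 1's test case (an assumption, since the question's own arrays aren't reproduced here). Comparing x against one row of z gives one boolean per element, so np.where also reports partial matches until the result is collapsed per row:

```python
import numpy as np

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9], [4, 0, 6, 19]])
z = np.array([[4, 0, 6, 9], [0, 1, 2, 3]])

# Element-wise comparison of every row of x against the single row z[0]:
# the result has the same shape as x, one boolean per element.
elementwise = x == z[0]
rows, cols = np.where(elementwise)

# np.where reports each matching *element*, so the partial match in row 2 shows up.
print(rows.tolist())  # [1, 1, 1, 1, 2, 2, 2]
print(cols.tolist())  # [0, 1, 2, 3, 0, 1, 2]

# Collapsing the last axis with .all(1) turns this into a per-row test.
row_hits = np.where(elementwise.all(1))[0]
print(row_hits.tolist())  # [1]
```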
Answer
You can find the duplicates using vectorized operations:
duplicates = (x[:, None] == z).all(-1).any(-1)
Get Values
To get the duplicate values, use boolean masking:
x[duplicates]
in this example:
duplicates = [True False]
x[duplicates] = [[0, 1, 2, 3]]
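The [True False] shown above appears to come from the question's own two-row example, which isn't reproduced here. As a sketch on Solution 1's test case instead (an assumption), the same mask yields both the matching rows and, via np.flatnonzero, the row indices Solution 1 returns:

```python
import numpy as np

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9], [4, 0, 6, 19]])
z = np.array([[4, 0, 6, 9], [0, 1, 2, 3]])

# One boolean per row of x: True if that row appears anywhere in z.
duplicates = (x[:, None] == z).all(-1).any(-1)
print(duplicates.tolist())  # [True, True, False]

# Boolean masking selects the matching rows themselves...
print(x[duplicates].tolist())  # [[0, 1, 2, 3], [4, 0, 6, 9]]

# ...while np.flatnonzero turns the mask into row indices.
print(np.flatnonzero(duplicates).tolist())  # [0, 1]
```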
Logic
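- [:, None] expands x from shape (n, k) to (n, 1, k), so comparing it with z (shape (m, k)) broadcasts to (n, m, k): every row of x against every row of z
- all(-1) keeps only full-row matches
- any(-1) marks the rows of x that match at least one row of z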
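A sketch that prints the intermediate shapes for each step of the list above, again using Solution 1's test case as assumed input:

```python
import numpy as np

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9], [4, 0, 6, 19]])  # shape (3, 4)
z = np.array([[4, 0, 6, 9], [0, 1, 2, 3]])                 # shape (2, 4)

expanded = x[:, None]           # (3, 1, 4): new axis for broadcasting
pairwise = expanded == z        # (3, 2, 4): every row of x vs every row of z
full_rows = pairwise.all(-1)    # (3, 2): True where all 4 elements agree
duplicates = full_rows.any(-1)  # (3,): True if the x row matches any z row

for name, arr in [("expanded", expanded), ("pairwise", pairwise),
                  ("full_rows", full_rows), ("duplicates", duplicates)]:
    print(name, arr.shape)

# Row i of full_rows says which rows of z equal row i of x.
print(full_rows.tolist())  # [[False, True], [True, False], [False, False]]
```

One design note: the (n, m, k) intermediate holds n*m*k booleans, so for 1000 x 1000 x 4 it is about 4 MB, and it grows quadratically with the number of rows.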
Solution 3:
Man, I haven't had a chance to link to this answer since np.unique added an axis parameter. Credit to @Jaime.
import numpy as np
vview = lambda a: np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))
Basically, that takes the "rows" of your matrix and turns each one into a single np.void element that views the row's raw bytes (an (n, k) array becomes an (n, 1) array). This lets you compare rows as if they were single values.
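A small sketch of what the view produces, assuming int64 input so the byte widths are predictable (the helper is copied from above; the -0.0 case is an extra illustration, not part of the answer). Because rows are compared as raw bytes, float rows that are numerically equal but bitwise different, such as 0.0 and -0.0, do not match:

```python
import numpy as np

# @Jaime's helper, as above: one np.void element per row, spanning its raw bytes.
vview = lambda a: np.ascontiguousarray(a).view(
    np.dtype((np.void, a.dtype.itemsize * a.shape[1])))

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9], [4, 0, 6, 19]], dtype=np.int64)
v = vview(x)

# Each 4-column int64 row (4 * 8 = 32 bytes) collapses into a single value,
# so the result keeps a trailing axis of length 1.
print(v.shape, v.dtype.itemsize)  # (3, 1) 32

# Identical rows have identical bytes, so their views compare equal.
print((v[1:2] == vview(x[1:2])).all())  # True

# Caveat for floats: the comparison is bytewise, and -0.0 differs from 0.0
# in its sign bit even though the two are numerically equal.
a = np.array([[0.0, 1.0]])
b = np.array([[-0.0, 1.0]])
print((a == b).all(), (vview(a) == vview(b)).all())  # True False
```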
Then it's fairly simple:
print(np.where(vview(x) == vview(z).T))
(array([0], dtype=int64), array([0], dtype=int64))
Representing that the first row of x matches the first row of z.
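On Solution 1's test case (assumed here because the question's arrays aren't shown), the two arrays returned by np.where pair up as (row of x, row of z). This sketch assumes both inputs share a dtype and column count, so the void widths agree:

```python
import numpy as np

vview = lambda a: np.ascontiguousarray(a).view(
    np.dtype((np.void, a.dtype.itemsize * a.shape[1])))

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9], [4, 0, 6, 19]])
z = np.array([[4, 0, 6, 9], [0, 1, 2, 3]])

# (3, 1) == (1, 2) broadcasts to a (3, 2) table of row-vs-row matches.
x_rows, z_rows = np.where(vview(x) == vview(z).T)

# Pair up the two index arrays: row i of x equals row j of z.
pairs = list(zip(x_rows.tolist(), z_rows.tolist()))
print(pairs)  # [(0, 1), (1, 0)]
```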
If you only want to know whether rows of x are in the rows of z:
print(np.where(np.isin(vview(x), vview(z)).squeeze()))
(array([0], dtype=int64),)
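The same membership test on Solution 1's test case (an assumed input). np.isin returns a mask with the (3, 1) shape of vview(x); this sketch uses ravel() instead of squeeze() so that a one-row x still gives a 1-D mask rather than a 0-d array:

```python
import numpy as np

vview = lambda a: np.ascontiguousarray(a).view(
    np.dtype((np.void, a.dtype.itemsize * a.shape[1])))

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9], [4, 0, 6, 19]])
z = np.array([[4, 0, 6, 9], [0, 1, 2, 3]])

# One boolean per row of x: is this row's byte pattern present among z's rows?
mask = np.isin(vview(x), vview(z)).ravel()
hits = np.where(mask)[0]

print(mask.tolist())  # [True, True, False]
print(hits.tolist())  # [0, 1]
```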
Checking times against @mujjiga's and @orgoro's answers on bigger arrays:
x = np.random.randint(10, size = (1000, 4))
z = np.random.randint(10, size = (1000, 4))
%timeit np.where(np.isin(vview(x), vview(z)).squeeze())
365 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit [i for i, e in enumerate(x) if (e == z).all(1).any()] # @mujjiga
21.3 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.where((x[:, None] == z).all(-1).any(-1)) # @orgoro
20 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
So that's about a 60x speedup over both looping and broadcasting, probably due to quick short-circuiting and only comparing 1/4 as many values (one void element per row instead of four integers).
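Before trusting timings, it is worth checking that the three approaches return the same rows. A sketch using a seeded generator (np.random.default_rng, swapped in for the np.random.randint calls above so the run is reproducible):

```python
import numpy as np

vview = lambda a: np.ascontiguousarray(a).view(
    np.dtype((np.void, a.dtype.itemsize * a.shape[1])))

rng = np.random.default_rng(0)
x = rng.integers(10, size=(1000, 4))
z = rng.integers(10, size=(1000, 4))

# The three answers, each reduced to a 1-D array of matching row indices of x.
via_void = np.where(np.isin(vview(x), vview(z)).ravel())[0]
via_loop = np.array([i for i, e in enumerate(x) if (e == z).all(1).any()])
via_broadcast = np.where((x[:, None] == z).all(-1).any(-1))[0]

print(np.array_equal(via_void, via_loop),
      np.array_equal(via_void, via_broadcast))
```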
Solution 4:
Well, for 2D arrays, something like the following might be useful. I think you'd have to be careful with the check ee==0 in floating-point arithmetic.
import numpy as np
aa = np.arange(16).reshape(4,4)
# we are trying to find the row in aa which is equal to bb
bb = np.asarray([0,1,2,3])
cc = bb[None,:]
dd = aa - cc
ee = np.linalg.norm(dd,axis=1)
idx = np.where(ee==0)
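For the floating-point concern above, one common alternative is to compare with a tolerance via np.isclose instead of testing the norm for exact zero. A sketch, where the 1e-12 offset and np.isclose's default tolerances are illustrative assumptions:

```python
import numpy as np

aa = np.arange(16, dtype=float).reshape(4, 4)
# A row that equals aa[0] only up to floating-point noise.
bb = np.array([0.0, 1.0, 2.0, 3.0]) + 1e-12

# The exact-norm test from above misses the near-match.
exact = np.where(np.linalg.norm(aa - bb[None, :], axis=1) == 0)[0]

# np.isclose broadcasts bb against every row; all(axis=1) requires every
# element of the row to be within tolerance.
approx = np.where(np.isclose(aa, bb).all(axis=1))[0]

print(exact.tolist(), approx.tolist())  # [] [0]
```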