I think my intuition of non-contiguous memory was correct, because I found a pretty hacky but surprisingly fast way of dealing with this:

```python
out = [0, 0, 0, 0, 0]
for i, Mi in enumerate(M):
    for j, Mij in enumerate(Mi):
        # Mij = (512*512)
        out[i] += Mij * L[j]
return [x.T for x in out]
```

Anyone looking at that code will probably wonder wth I am doing there, but it works and is fast. I don't know if there is already an elegant way of circumventing the issues I had, or how viable the idea would be, but something like a lambda-allocator would be very nice: basically, a function that takes a lambda expression as input and populates a contiguous-memory tensor whose entry at each index is the value of the lambda at that index.

```python
f = lambda i, j: (i == j) * 1.0
I = lambda_allocator(f, bounds=[(0, 10), (0, 10)])    # equivalent to jnp.eye(10)

f = lambda i, j: jnp.zeros((100,))
M = lambda_allocator(f, bounds=[(0, 13), (0, 10)])    # equivalent to jnp.zeros((13, 10, 100))
```
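Something along these lines can probably already be approximated with `jax.vmap` over index grids. A rough sketch, assuming the hypothetical `lambda_allocator(f, bounds)` interface from above (not an existing JAX API):

```python
import jax
import jax.numpy as jnp

def lambda_allocator(f, bounds):
    # Sketch only: materialize f over an index grid as a single contiguous array.
    # f maps integer indices to a scalar or an array; bounds is a list of
    # (lo, hi) half-open ranges, one per index.
    axes = [jnp.arange(lo, hi) for lo, hi in bounds]
    grids = jnp.meshgrid(*axes, indexing="ij")
    flat = [g.ravel() for g in grids]
    vals = jax.vmap(f)(*flat)                        # (prod(sizes), *f_output_shape)
    sizes = tuple(hi - lo for lo, hi in bounds)
    return vals.reshape(sizes + vals.shape[1:])

I = lambda_allocator(lambda i, j: (i == j) * 1.0, bounds=[(0, 10), (0, 10)])       # ~ jnp.eye(10)
M = lambda_allocator(lambda i, j: jnp.zeros((100,)), bounds=[(0, 13), (0, 10)])    # ~ jnp.zeros((13, 10, 100))
```

The result comes out as a single array, but I have not benchmarked this at the sizes discussed here.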
I am experiencing undesired behavior which leads to a significant slowdown.
I am not sure if I am on the right track to solving this issue, but I think it has something to do with the memory management of my function.
I have a function that gets called hundreds of times, and it is very important for my application that it evaluates quickly. Essentially, the function computes the dot product of a matrix that I need to generate and another matrix that is provided, i.e. `M(r, d) @ L`.

`M` and `L` have fairly large leading dimensions of `N = 250,000` to `1,000,000` (but `N` is fixed and does not vary within a run). `r` and `d` have dimensions `(3, N)`, which in turn means `M` has shape `(N, 35, 5)` and `L` has dimensions `(N, 4, 35)`. The problem I have run into is that there is no "easy" way to express `M` in terms of `r` and `d` (i.e. as a sequence of matmuls or something), so I am currently generating `M` from a list of submatrices (a reduced sketch of the pattern is at the end of this post). This is very fast, about 1.2 ms, which is roughly what I need, but I now have a list of arrays that I unfortunately cannot shove into `dot`.

As soon as I turn this sequence of arrays into a `jnp.array` (`jnp.asarray` does not help) and transpose it so I can use `M @ L`, i.e. if I do `return jnp.array(M).transpose((2, 0, 1))`, the function becomes 40x slower.

I have tried various things, like using `einsum` to avoid the `transpose`, but all of these strategies have failed. I believe my problem is that when I generate the tuple of tuples `M`, the respective blocks of memory are fragmented, and as soon as I call `transpose`, XLA starts copying memory around to make the array `M` contiguous. Am I interpreting this behavior correctly?
Is there a way of directly generating the submatrices at the "correct location in terms of memory" such that it already is contiguous and does not require copying?
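To make the pattern concrete, here is a reduced sketch with toy sizes; the block construction and the contraction pattern are only stand-ins for the real ones, and the `einsum` variant at the end is the kind of thing I tried in order to avoid the `transpose`:

```python
import jax.numpy as jnp
from jax import random

N = 1_000                                     # toy stand-in for the real N (250k to 1M)
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
r = random.normal(k1, (3, N))
d = random.normal(k2, (3, N))
L = random.normal(k3, (N, 4, 35))

# Stand-in for the real construction: each block is some cheap elementwise
# expression in r and d with shape (N,), arranged as a 35 x 5 nested list.
M_blocks = [[r[0] * d[0] * (i + j + 1) for j in range(5)] for i in range(35)]

# The slow path: materialize the nested list and transpose so the batch
# dimension N comes first, then contract.
M = jnp.array(M_blocks).transpose((2, 0, 1))              # (N, 35, 5)
out = jnp.einsum('nki,nij->nkj', L, M)                    # (N, 4, 5), illustrative contraction

# One attempted alternative: contract directly on the untransposed stack and
# let einsum handle the axis order instead of an explicit transpose.
out2 = jnp.einsum('ijn,nki->nkj', jnp.array(M_blocks), L) # (N, 4, 5)
```

In both variants the nested list of blocks still has to be copied into one contiguous buffer by `jnp.array`, which is the cost I am trying to avoid.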