Is the field initialization parallelized? #3141
zhiping0913 started this conversation in General
I tried applying `jax.jit` to the `get_id_int` function, and it did reduce the time of each call:

```python
import jax
import jax.numpy as jnp
import meep as mp

# x_axis, y_axis, dx, dy, Nx, Ny, x_center, y_center and Electric_Field_Ex
# are defined elsewhere in the script.

@jax.jit
def _get_axis_id_int(axis_start, dr, N, pos):
    # Index of the grid point nearest to pos, clipped to [0, N-1].
    id_float = (pos - axis_start) / dr
    # Without jax_enable_x64, JAX truncates int64 to int32 (with a warning).
    id_int = jnp.clip(jnp.round(id_float), 0, N - 1).astype(jnp.int64)
    return id_int

@jax.jit
def get_x_id_int(pos):
    return _get_axis_id_int(axis_start=x_axis[0], dr=dx, N=Nx, pos=pos)

@jax.jit
def get_y_id_int(pos):
    return _get_axis_id_int(axis_start=y_axis[0], dr=dy, N=Ny, pos=pos)

def initialize_Ex(pos: mp.Vector3):
    # Convert the Meep position to the coordinates of the given array.
    x = pos.x + x_center
    y = pos.y + y_center
    Ex = Electric_Field_Ex[get_x_id_int(x), get_y_id_int(y)]
    return Ex
```

The initialization still took minutes (times below are in µs):

```
Total time: 479.68 s
File: /start_2D.py
Function: initialize_Ex at line 71

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    71                                           @profile
    72                                           def initialize_Ex(pos:mp.Vector3):
    73   9619209    3406122.5      0.4      0.7      x=pos.x+x_center
    74   9619209    3357300.6      0.3      0.7      y=pos.y+y_center
    75   9619209  468507783.6     48.7     97.7      Ex=Electric_Field_Ex[get_x_id_int(x), get_y_id_int(y)]
    76   9619209    4408977.3      0.5      0.9      return Ex
```

Although `jax.jit` speeds up the lookup in the given array, it cannot vectorize the initialization as a whole: `initialize_Ex` is still called once per grid point, and the lookup line alone accounts for 97.7% of the time (about 49 µs per call).
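A likely contributor to the ~49 µs per call on line 75 is JAX's fixed Python-side dispatch overhead: for scalar inputs, each call to a jitted function can cost far more than the arithmetic it performs. The sketch below is a hedged alternative, not a tested Meep fix: it does the same round-and-clip nearest-index lookup in plain Python. It assumes uniformly spaced axes, as `_get_axis_id_int` does; `axis_index` and `make_lookup` are hypothetical names.

```python
# Sketch: nearest-grid-index lookup in plain Python, assuming a uniform axis
# that starts at axis_start with spacing dr and n points.
# axis_index and make_lookup are hypothetical helper names.

def axis_index(coord, axis_start, dr, n):
    """Index of the grid point nearest to coord, clipped to [0, n-1].

    Python's round() rounds half to even, like jnp.round.
    """
    i = round((coord - axis_start) / dr)
    return min(max(i, 0), n - 1)


def make_lookup(field, x0, dx, y0, dy):
    """Return lookup(x, y) -> field[ix][iy] for a 2D nested list field."""
    nx, ny = len(field), len(field[0])

    def lookup(x, y):
        return field[axis_index(x, x0, dx, nx)][axis_index(y, y0, dy, ny)]

    return lookup
```

In the script above this would be used as `lookup = make_lookup(Electric_Field_Ex.tolist(), x_axis[0], dx, y_axis[0], dy)`, with the callback returning `lookup(pos.x + x_center, pos.y + y_center)`. Converting with `.tolist()` makes each lookup return a plain Python number. This takes JAX out of the per-point path, but there is still one Python call per grid point.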
Hi,
I am trying to initialize the fields from a given array (obtained from other simulations or calculations) by passing a Python function, `initialize_Ey(pos)`, to `sim.fields.initialize_field`.
This initialization method works. However, it is very slow!
For example, initializing a 2D simulation of modest size takes hours.
Profiling with line_profiler shows what is happening: `sim.fields.initialize_field` calls `initialize_Ey` about Nx*Ny = 9,600,000 times (once for every point in the simulation cell), at about 3 μs per call.
The initialization process is not parallelized or even vectorized.
Because `initialize_Ey` must take a single `mp.Vector3` argument (as `sim.fields.initialize_field` requires), it is not easy to vectorize with tools like `jax.vmap`.
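The obstacle here is the per-point callback interface, not the lookup itself. As a sketch, assuming the target coordinates are available as 1D arrays `xs` and `ys` (hypothetical names), NumPy can sample the whole array in one vectorized step. `nearest_indices` and `resample_nearest` are illustrative helpers that use the same round-and-clip nearest-index rule as `_get_axis_id_int`.

```python
import numpy as np


def nearest_indices(coords, axis_start, dr, n):
    """Vectorized nearest-grid-point indices, clipped to [0, n-1].

    np.rint rounds half to even, like jnp.round.
    """
    idx = np.rint((np.asarray(coords, dtype=float) - axis_start) / dr)
    return np.clip(idx, 0, n - 1).astype(np.intp)


def resample_nearest(field, x0, dx, y0, dy, xs, ys):
    """Sample a 2D array at every (x, y) pair of the grid xs x ys at once."""
    ix = nearest_indices(xs, x0, dx, field.shape[0])
    iy = nearest_indices(ys, y0, dy, field.shape[1])
    # np.ix_ builds an open mesh, so the result has shape (len(xs), len(ys)).
    return field[np.ix_(ix, iy)]
```

This is roughly what `jax.vmap` would provide if the points arrived as arrays. With `sim.fields.initialize_field`, however, the callback is still invoked once per point, so precomputing this way only moves the rounding arithmetic out of that loop.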
The question is: how can the fields be initialized more efficiently?
Can a similar initialization be done with C++ Meep?