[first file in the diff]
@@ -60,25 +60,12 @@ def __init__(self, cfg: AssemblyEnvCfg, render_mode: str | None = None, **kwargs
)

# Create criterion for dynamic time warping (later used for imitation reward)
- self.soft_dtw_criterion = SoftDTW(use_cuda=True, gamma=self.cfg_task.soft_dtw_gamma)
+ self.soft_dtw_criterion = SoftDTW(use_cuda=True, device=self.device, gamma=self.cfg_task.soft_dtw_gamma)

# Evaluate
if self.cfg_task.if_logging_eval:
self._init_eval_logging()

- if self.cfg_task.sample_from != "rand":
- self._init_eval_loading()
-
- def _init_eval_loading(self):
- eval_held_asset_pose, eval_fixed_asset_pose, eval_success = automate_log.load_log_from_hdf5(
- self.cfg_task.eval_filename
- )
-
- if self.cfg_task.sample_from == "gp":
- self.gp = automate_algo.model_succ_w_gp(eval_held_asset_pose, eval_fixed_asset_pose, eval_success)
- elif self.cfg_task.sample_from == "gmm":
- self.gmm = automate_algo.model_succ_w_gmm(eval_held_asset_pose, eval_fixed_asset_pose, eval_success)
-
def _init_eval_logging(self):

self.held_asset_pose_log = torch.empty(
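Note on the constructor change above: passing the environment's own device keeps the criterion's internal tensors on the simulation GPU. A minimal sketch of the wiring, assuming the SoftDTW class from the soft-DTW module further down in this diff is in scope (the class name below and the gamma value are placeholders):

```python
import torch

# Trimmed-down, hypothetical environment: forwarding self.device means the
# small gamma/bandwidth tensors SoftDTW allocates internally land on the same
# device as the simulation buffers, not on whichever CUDA device is current.
class MiniAssemblyEnv:
    def __init__(self, device: str = "cuda:0", soft_dtw_gamma: float = 0.01):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.soft_dtw_criterion = SoftDTW(
            use_cuda=self.device.type == "cuda",
            device=self.device,
            gamma=soft_dtw_gamma,
        )
```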
@@ -246,7 +233,7 @@ def _load_disassembly_data(self):
# offset each trajectory to be relative to the goal
eef_pos_traj.append(curr_ee_traj - curr_ee_goal)

- self.eef_pos_traj = torch.tensor(eef_pos_traj, dtype=torch.float32, device=self.device).squeeze()
+ self.eef_pos_traj = torch.tensor(np.array(eef_pos_traj), dtype=torch.float32, device=self.device).squeeze()

def _get_keypoint_offsets(self, num_keypoints):
"""Get uniformly-spaced keypoints along a line of unit length, centered at 0."""
@@ -804,28 +791,12 @@ def randomize_held_initial_state(self, env_ids, pre_grasp):
torch.rand((self.num_envs,), dtype=torch.float32, device=self.device)
)

if self.cfg_task.sample_from == "rand":

rand_sample = torch.rand((len(env_ids), 3), dtype=torch.float32, device=self.device)
held_pos_init_rand = 2 * (rand_sample - 0.5) # [-1, 1]
held_asset_init_pos_rand = torch.tensor(
self.cfg_task.held_asset_init_pos_noise, dtype=torch.float32, device=self.device
)
self.held_pos_init_rand = held_pos_init_rand @ torch.diag(held_asset_init_pos_rand)

if self.cfg_task.sample_from == "gp":
rand_sample = torch.rand((self.cfg_task.num_gp_candidates, 3), dtype=torch.float32, device=self.device)
held_pos_init_rand = 2 * (rand_sample - 0.5) # [-1, 1]
held_asset_init_pos_rand = torch.tensor(
self.cfg_task.held_asset_init_pos_noise, dtype=torch.float32, device=self.device
)
held_asset_init_candidates = held_pos_init_rand @ torch.diag(held_asset_init_pos_rand)
self.held_pos_init_rand, _ = automate_algo.propose_failure_samples_batch_from_gp(
self.gp, held_asset_init_candidates.cpu().detach().numpy(), len(env_ids), self.device
)

if self.cfg_task.sample_from == "gmm":
self.held_pos_init_rand = automate_algo.sample_rel_pos_from_gmm(self.gmm, len(env_ids), self.device)
rand_sample = torch.rand((len(env_ids), 3), dtype=torch.float32, device=self.device)
held_pos_init_rand = 2 * (rand_sample - 0.5) # [-1, 1]
held_asset_init_pos_rand = torch.tensor(
self.cfg_task.held_asset_init_pos_noise, dtype=torch.float32, device=self.device
)
self.held_pos_init_rand = held_pos_init_rand @ torch.diag(held_asset_init_pos_rand)

# Set plug pos to assembled state, but offset plug Z-coordinate by height of socket,
# minus curriculum displacement
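The retained branch above reduces to per-axis uniform noise: draw in [0, 1), shift to [-1, 1), then scale each axis by its configured magnitude through a diagonal matrix. A self-contained sketch with placeholder noise magnitudes:

```python
import torch

num_envs = 4
# Placeholder per-axis magnitudes; the real values come from
# cfg_task.held_asset_init_pos_noise.
held_asset_init_pos_noise = torch.tensor([0.01, 0.01, 0.005])

rand_sample = torch.rand((num_envs, 3))        # U[0, 1)
held_pos_init_rand = 2 * (rand_sample - 0.5)   # U[-1, 1)
offsets = held_pos_init_rand @ torch.diag(held_asset_init_pos_noise)
# Row-wise this is plain element-wise scaling, so
# held_pos_init_rand * held_asset_init_pos_noise gives the same result.
```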
[next file in the diff]
@@ -139,10 +139,6 @@ class AssemblyTask:
num_eval_trials: int = 100
eval_filename: str = "evaluation_00015.h5"

- # Fine-tuning
- sample_from: str = "rand" # gp, gmm, idv, rand
- num_gp_candidates: int = 1000
-

@configclass
class Peg8mm(HeldAssetCfg):
[next file in the diff]
@@ -118,7 +118,7 @@ class Extraction(DisassemblyTask):
assembly_id = "00015"
assembly_dir = f"{ASSET_DIR}/{assembly_id}/"
disassembly_dir = "disassembly_dir"
- num_log_traj = 1000
+ num_log_traj = 100

fixed_asset_cfg = Hole8mm()
held_asset_cfg = Peg8mm()
[next file in the diff]
@@ -59,11 +59,12 @@ def main():

update_task_param(args.cfg_path, args.assembly_id, args.train, args.log_eval)

- bash_command = None
+ # avoid the warning of low GPU occupancy for SoftDTWCUDA function
+ bash_command = "NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS=0"
if sys.platform.startswith("win"):
bash_command = "isaaclab.bat -p"
bash_command += " isaaclab.bat -p"
elif sys.platform.startswith("linux"):
bash_command = "./isaaclab.sh -p"
bash_command += " ./isaaclab.sh -p"
if args.train:
bash_command += " scripts/reinforcement_learning/rl_games/train.py --task=Isaac-AutoMate-Assembly-Direct-v0"
bash_command += f" --seed={str(args.seed)} --max_iterations={str(args.max_iterations)}"
[next file in the diff]
@@ -120,11 +120,11 @@ class _SoftDTWCUDA(Function):
"""

@staticmethod
- def forward(ctx, D, gamma, bandwidth):
+ def forward(ctx, D, device, gamma, bandwidth):
dev = D.device
dtype = D.dtype
- gamma = torch.cuda.FloatTensor([gamma])
- bandwidth = torch.cuda.FloatTensor([bandwidth])
+ gamma = torch.tensor([gamma], dtype=torch.float, device=device)
+ bandwidth = torch.tensor([bandwidth], dtype=torch.float, device=device)

B = D.shape[0]
N = D.shape[1]
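For context on the replacement above: `torch.cuda.FloatTensor` is a legacy constructor that warns on recent PyTorch and always allocates on the *current* CUDA device, whereas `torch.tensor(..., device=...)` pins the scalar to whatever device the caller hands in. A quick illustration:

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Legacy form (implicit current-device allocation, warns on recent PyTorch):
#   gamma = torch.cuda.FloatTensor([1.0])
# Explicit form, matching the change above:
gamma = torch.tensor([1.0], dtype=torch.float, device=device)
bandwidth = torch.tensor([0.0], dtype=torch.float, device=device)
print(gamma.device, bandwidth.device)
```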
@@ -255,7 +255,7 @@ class _SoftDTW(Function):
"""

@staticmethod
- def forward(ctx, D, gamma, bandwidth):
+ def forward(ctx, D, device, gamma, bandwidth):
dev = D.device
dtype = D.dtype
gamma = torch.Tensor([gamma]).to(dev).type(dtype) # dtype fixed
@@ -286,10 +286,11 @@ class SoftDTW(torch.nn.Module):
The soft DTW implementation that optionally supports CUDA
"""

- def __init__(self, use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_func=None):
+ def __init__(self, use_cuda, device, gamma=1.0, normalize=False, bandwidth=None, dist_func=None):
"""
Initializes a new instance using the supplied parameters
:param use_cuda: Flag indicating whether the CUDA implementation should be used
+ :param device: device to run the soft dtw computation
:param gamma: sDTW's gamma parameter
:param normalize: Flag indicating whether to perform normalization
(as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790)
@@ -301,6 +302,7 @@ def __init__(self, use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_fu
self.gamma = gamma
self.bandwidth = 0 if bandwidth is None else float(bandwidth)
self.use_cuda = use_cuda
+ self.device = device

# Set the distance function
if dist_func is not None:
@@ -357,12 +359,12 @@ def forward(self, X, Y):
x = torch.cat([X, X, Y])
y = torch.cat([Y, X, Y])
D = self.dist_func(x, y)
- out = func_dtw(D, self.gamma, self.bandwidth)
+ out = func_dtw(D, self.device, self.gamma, self.bandwidth)
out_xy, out_xx, out_yy = torch.split(out, X.shape[0])
return out_xy - 1 / 2 * (out_xx + out_yy)
else:
D_xy = self.dist_func(X, Y)
- return func_dtw(D_xy, self.gamma, self.bandwidth)
+ return func_dtw(D_xy, self.device, self.gamma, self.bandwidth)


# ----------------------------------------------------------------------------------------------------------------------
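Putting the signature changes together, here is a usage sketch of the updated class as defined above (batch shapes and the gamma value are placeholders; turning the distance into a reward is just one plausible use):

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder batch of trajectories: (batch, sequence length, feature dim).
pred_traj = torch.rand(16, 40, 3, device=device)
ref_traj = torch.rand(16, 40, 3, device=device)

criterion = SoftDTW(use_cuda=device.type == "cuda", device=device, gamma=0.1)
dtw_dist = criterion(pred_traj, ref_traj)  # one soft-DTW value per batch element
imitation_reward = -dtw_dist               # smaller distance -> larger reward
print(dtw_dist.shape)                      # torch.Size([16])
```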
[remaining files in this diff were not loaded]