Skip to content

Commit b048ad6

Browse files
committed
feat(rfl): unify rfl and rflrelax training schedules as part of run.py
1 parent 4c8806d commit b048ad6

File tree

8 files changed

+116
-98
lines changed

8 files changed

+116
-98
lines changed

include/neural-graphics-primitives/fused_kernels/render_nerf.cuh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020
using namespace ngp;
2121

22-
__launch_bounds__(128, 4)
23-
__global__ void render_nerf(
22+
__launch_bounds__(128, 4) __global__ void render_nerf(
2423
uint32_t sample_index,
2524
ivec2 resolution,
2625
vec2 focal_length,
@@ -65,7 +64,8 @@ __global__ void render_nerf(
6564

6665
vec2 pixel_offset = ld_random_pixel_offset(snap_to_pixel_centers ? 0 : sample_index);
6766
vec2 uv = vec2{(float)x + pixel_offset.x, (float)y + pixel_offset.y} / vec2(resolution);
68-
mat4x3 camera = get_xform_given_rolling_shutter({camera_matrix0, camera_matrix1}, rolling_shutter, uv, ld_random_val(sample_index, idx * 72239731));
67+
mat4x3 camera =
68+
get_xform_given_rolling_shutter({camera_matrix0, camera_matrix1}, rolling_shutter, uv, ld_random_val(sample_index, idx * 72239731));
6969

7070
Ray ray = uv_to_ray(
7171
sample_index,
@@ -108,7 +108,9 @@ __global__ void render_nerf(
108108
vec3 pos = cam_pos;
109109

110110
if (alive) {
111-
t = if_unoccupied_advance_to_next_occupied_voxel(t, cone_angle, ray, idir, density_grid, min_mip, max_mip, render_aabb, render_aabb_to_local);
111+
t = if_unoccupied_advance_to_next_occupied_voxel(
112+
t, cone_angle, ray, idir, density_grid, min_mip, max_mip, render_aabb, render_aabb_to_local
113+
);
112114
alive &= t < MAX_DEPTH();
113115
if (alive) {
114116
pos = ray(t);

include/neural-graphics-primitives/fused_kernels/train_nerf.cuh

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ __global__ void train_nerf(
7878
float depth_supervision_lambda,
7979
float near_distance,
8080

81-
uint32_t training_step,
82-
ETrainMode training_mode,
83-
uint32_t rfl_warmup_steps
81+
ETrainMode training_mode
8482
) {
8583
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
8684

@@ -222,6 +220,7 @@ __global__ void train_nerf(
222220
color += vec4(rgb * weight, weight);
223221

224222
loss_bg += weight * loss_and_gradient(rgbtarget, rgb, loss_type).loss;
223+
225224
hitpoint += weight * pos;
226225

227226
if (1.0f - color.a < EPSILON || j >= NERF_STEPS()) {
@@ -233,7 +232,7 @@ __global__ void train_nerf(
233232
hitpoint /= color.a;
234233

235234
uint32_t numsteps = j;
236-
uint32_t base = atomicAdd(numsteps_counter, numsteps); // first entry in the array is a counter
235+
uint32_t base = atomicAdd(numsteps_counter, numsteps); // first entry in the array is a counter
237236
numsteps = min(max_samples - min(max_samples, base), numsteps);
238237
bool can_write = numsteps > 0;
239238

@@ -245,8 +244,8 @@ __global__ void train_nerf(
245244
if (can_write) {
246245
ray_idx = atomicAdd(ray_counter, 1);
247246
ray_indices_out[ray_idx] = i;
248-
numsteps_out[ray_idx*2+0] = numsteps;
249-
numsteps_out[ray_idx*2+1] = base;
247+
numsteps_out[ray_idx * 2 + 0] = numsteps;
248+
numsteps_out[ray_idx * 2 + 1] = base;
250249
}
251250

252251
if (1.0f - color.a >= EPSILON) {
@@ -258,7 +257,8 @@ __global__ void train_nerf(
258257
LossAndGradient lg = loss_and_gradient(rgbtarget, color.rgb(), loss_type);
259258
lg.loss /= img_pdf * uv_pdf;
260259

261-
float target_depth = ray_length * ((depth_supervision_lambda > 0.0f && metadata[img].depth) ? read_depth(uv, resolution, metadata[img].depth) : -1.0f);
260+
float target_depth = ray_length *
261+
((depth_supervision_lambda > 0.0f && metadata[img].depth) ? read_depth(uv, resolution, metadata[img].depth) : -1.0f);
262262
LossAndGradient lg_depth = loss_and_gradient(vec3(target_depth), vec3(depth), depth_loss_type);
263263
float depth_loss_gradient = target_depth > 0.0f ? depth_supervision_lambda * lg_depth.gradient.x : 0;
264264

@@ -286,25 +286,28 @@ __global__ void train_nerf(
286286

287287
if (sharpness_data && aabb.contains(hitpoint)) {
288288
ivec2 sharpness_pos = clamp(ivec2(uv * vec2(sharpness_resolution)), 0, sharpness_resolution - 1);
289-
float sharp = sharpness_data[img * product(sharpness_resolution) + sharpness_pos.y * sharpness_resolution.x + sharpness_pos.x] + 1e-6f;
289+
float sharp = sharpness_data[img * product(sharpness_resolution) + sharpness_pos.y * sharpness_resolution.x + sharpness_pos.x] +
290+
1e-6f;
290291

291292
// The maximum value of positive floats interpreted in uint format is the same as the maximum value of the floats.
292-
float grid_sharp = __uint_as_float(atomicMax((uint32_t*)&cascaded_grid_at(hitpoint, sharpness_grid, mip_from_pos(hitpoint, max_mip)), __float_as_uint(sharp)));
293+
float grid_sharp = __uint_as_float(
294+
atomicMax((uint32_t*)&cascaded_grid_at(hitpoint, sharpness_grid, mip_from_pos(hitpoint, max_mip)), __float_as_uint(sharp))
295+
);
293296
grid_sharp = fmaxf(sharp, grid_sharp); // atomicMax returns the old value, so compute the new one locally.
294297

295298
mean_loss *= fmaxf(sharp / grid_sharp, 0.01f);
296299
}
297300

298-
deposit_val(idx.x, idx.y, (1 - weight.x) * (1 - weight.y) * mean_loss);
299-
deposit_val(idx.x+1, idx.y, weight.x * (1 - weight.y) * mean_loss);
300-
deposit_val(idx.x, idx.y+1, (1 - weight.x) * weight.y * mean_loss);
301-
deposit_val(idx.x+1, idx.y+1, weight.x * weight.y * mean_loss);
301+
deposit_val(idx.x, idx.y, (1 - weight.x) * (1 - weight.y) * mean_loss);
302+
deposit_val(idx.x + 1, idx.y, weight.x * (1 - weight.y) * mean_loss);
303+
deposit_val(idx.x, idx.y + 1, (1 - weight.x) * weight.y * mean_loss);
304+
deposit_val(idx.x + 1, idx.y + 1, weight.x * weight.y * mean_loss);
302305
}
303306

304307
loss_scale /= n_rays;
305308

306309
const float output_l2_reg = rgb_activation == ENerfActivation::Exponential ? 1e-4f : 0.0f;
307-
const float output_l1_reg_density = 0.0f;// *mean_density_ptr < NERF_MIN_OPTICAL_THICKNESS() ? 1e-4f : 0.0f;
310+
const float output_l1_reg_density = 0.0f; // *mean_density_ptr < NERF_MIN_OPTICAL_THICKNESS() ? 1e-4f : 0.0f;
308311

309312
// now do it again computing gradients
310313
vec4 color2 = vec4(0.0f);
@@ -369,12 +372,13 @@ __global__ void train_nerf(
369372
continue;
370373
}
371374

372-
coords_out(j-1)->copy(*(NerfCoordinate*)&nerf_in[0], coords_out.stride_in_bytes);
375+
coords_out(j - 1)->copy(*(NerfCoordinate*)&nerf_in[0], coords_out.stride_in_bytes);
373376
if (max_level_rand_training) {
374-
max_level_ptr[j-1] = max_level;
377+
max_level_ptr[j - 1] = max_level;
375378
}
376379

377-
// we know the suffix of this ray compared to where we are up to. note the suffix depends on this step's alpha as suffix = (1-alpha)*(somecolor), so dsuffix/dalpha = -somecolor = -suffix/(1-alpha)
380+
// we know the suffix of this ray compared to where we are up to. note the suffix depends on this step's alpha as suffix =
381+
// (1-alpha)*(somecolor), so dsuffix/dalpha = -somecolor = -suffix/(1-alpha)
378382
const vec3 suffix = color.rgb() - color2.rgb();
379383

380384
float density_derivative = network_to_density_derivative(float(local_network_output[3]), density_activation);
@@ -383,17 +387,13 @@ __global__ void train_nerf(
383387

384388
vec3 dloss_by_drgb;
385389
float dloss_by_dmlp;
386-
if (training_mode == ETrainMode::Rfl && training_step < rfl_warmup_steps) {
387-
training_mode = ETrainMode::Nerf; // Warm up training
388-
}
390+
389391
if (training_mode == ETrainMode::Rfl) {
390392
// Radiance field loss
391393
LossAndGradient local_lg = loss_and_gradient(rgbtarget, rgb, loss_type);
392394
loss_bg2 += weight * local_lg.loss;
393395
dloss_by_drgb = weight * local_lg.gradient;
394-
dloss_by_dmlp = density_derivative * (
395-
dt * sum(T * local_lg.loss - (loss_bg - loss_bg2) + depth_supervision)
396-
);
396+
dloss_by_dmlp = density_derivative * (dt * sum(T * local_lg.loss - (loss_bg - loss_bg2) + depth_supervision));
397397
} else if (training_mode == ETrainMode::RflRelax) {
398398
// In-between volume reconstruction and surface reconstruction.
399399
// This is different from the relaxation in the paper, but is much simpler and also promotes surfaces.
@@ -402,32 +402,33 @@ __global__ void train_nerf(
402402
LossAndGradient local_lg = loss_and_gradient(rgbtarget, rgb_lerp, loss_type);
403403

404404
dloss_by_drgb = weight * local_lg.gradient;
405-
dloss_by_dmlp = density_derivative * (
406-
dt * (dot(local_lg.gradient, T * rgb - suffix) + depth_supervision)
407-
);
405+
dloss_by_dmlp = density_derivative * (dt * (dot(local_lg.gradient, T * rgb - suffix) + depth_supervision));
408406
} else {
409407
// The original NeRF loss
410408
dloss_by_drgb = weight * lg.gradient;
411-
dloss_by_dmlp = density_derivative * (
412-
dt * (dot(lg.gradient, T * rgb - suffix) + depth_supervision)
413-
);
409+
dloss_by_dmlp = density_derivative * (dt * (dot(lg.gradient, T * rgb - suffix) + depth_supervision));
414410
}
415411

416412
tvec<network_precision_t, 4> local_dL_doutput;
417413

418414
// chain rule to go from dloss/drgb to dloss/dmlp_output
419-
local_dL_doutput[0] = loss_scale * (dloss_by_drgb.x * network_to_rgb_derivative(local_network_output[0], rgb_activation) + fmaxf(0.0f, output_l2_reg * (float)local_network_output[0])); // Penalize way too large color values
420-
local_dL_doutput[1] = loss_scale * (dloss_by_drgb.y * network_to_rgb_derivative(local_network_output[1], rgb_activation) + fmaxf(0.0f, output_l2_reg * (float)local_network_output[1]));
421-
local_dL_doutput[2] = loss_scale * (dloss_by_drgb.z * network_to_rgb_derivative(local_network_output[2], rgb_activation) + fmaxf(0.0f, output_l2_reg * (float)local_network_output[2]));
422-
423-
//static constexpr float mask_supervision_strength = 1.f; // we are already 'leaking' mask information into the nerf via the random bg colors; setting this to eg between 1 and 100 encourages density towards 0 in such regions.
424-
//dloss_by_dmlp += (texsamp.a<0.001f) ? mask_supervision_strength * weight : 0.f;
425-
426-
local_dL_doutput[3] =
427-
loss_scale * dloss_by_dmlp +
428-
(float(local_network_output[3]) < 0.0f ? -output_l1_reg_density : 0.0f) +
415+
local_dL_doutput[0] = loss_scale *
416+
(dloss_by_drgb.x * network_to_rgb_derivative(local_network_output[0], rgb_activation) +
417+
fmaxf(0.0f, output_l2_reg * (float)local_network_output[0])); // Penalize way too large color values
418+
local_dL_doutput[1] = loss_scale *
419+
(dloss_by_drgb.y * network_to_rgb_derivative(local_network_output[1], rgb_activation) +
420+
fmaxf(0.0f, output_l2_reg * (float)local_network_output[1]));
421+
local_dL_doutput[2] = loss_scale *
422+
(dloss_by_drgb.z * network_to_rgb_derivative(local_network_output[2], rgb_activation) +
423+
fmaxf(0.0f, output_l2_reg * (float)local_network_output[2]));
424+
425+
// static constexpr float mask_supervision_strength = 1.f; // we are already 'leaking' mask information into the nerf via the random
426+
// bg colors; setting this to eg between 1 and 100 encourages density towards 0 in such regions.
427+
// dloss_by_dmlp += (texsamp.a<0.001f) ? mask_supervision_strength * weight : 0.f;
428+
429+
local_dL_doutput[3] = loss_scale * dloss_by_dmlp + (float(local_network_output[3]) < 0.0f ? -output_l1_reg_density : 0.0f) +
429430
(float(local_network_output[3]) > -10.0f && local_depth < near_distance ? 1e-4f : 0.0f);
430-
;
431+
;
431432

432433
*(tvec<network_precision_t, 4>*)dloss_doutput = local_dL_doutput;
433434
dloss_doutput += padded_output_width;

include/neural-graphics-primitives/testbed.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,6 @@ class Testbed {
820820
int view = 0;
821821

822822
ETrainMode train_mode = ETrainMode::RflRelax;
823-
int rfl_warmup_steps = 1000;
824823

825824
float depth_supervision_lambda = 0.f;
826825

scripts/run.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ def parse_args():
4040
parser.add_argument("--test_transforms", default="", help="Path to a nerf style transforms json from which we will compute PSNR.")
4141
parser.add_argument("--near_distance", default=-1, type=float, help="Set the distance from the camera at which training rays start for nerf. <0 means use ngp default")
4242
parser.add_argument("--exposure", default=0.0, type=float, help="Controls the brightness of the image. Positive numbers increase brightness, negative numbers decrease it.")
43+
4344
parser.add_argument("--train_mode", default="", type=str, help="The training mode to use. Can be 'nerf', 'rfl', 'rfl_relax'. If not specified, the default mode will be used.")
4445
parser.add_argument("--rfl_warmup_steps", type=int, default=1000, help="Number of steps to train in NeRF mode before switching to RFL mode. Default is 1000. Only used if --train_mode is set to 'rfl'.")
45-
parser.add_argument("--no_rflrelax_training_schedule", action="store_true", help="Disable RFL training schedule for RflRelax mode (active between steps 15k-30k).")
46+
parser.add_argument("--rflrelax_begin_step", type=int, default=15000, help="First training step in which RflRelax mode is used. Default is 15000. Only used if --train_mode is set to 'rflrelax'.")
47+
parser.add_argument("--rflrelax_end_step", type=int, default=30000, help="Last training step in which RflRelax mode is used. Default is 30000. Only used if --train_mode is set to 'rflrelax'.")
4648

4749
parser.add_argument("--screenshot_transforms", default="", help="Path to a nerf style transforms.json from which to save screenshots.")
4850
parser.add_argument("--screenshot_frames", nargs="*", help="Which frame(s) to take screenshots of.")
@@ -159,8 +161,6 @@ def get_scene(scene):
159161
else:
160162
raise ValueError(f"Unknown train mode: {args.train_mode}")
161163

162-
testbed.nerf.training.rfl_warmup_steps = args.rfl_warmup_steps
163-
164164
if args.nerf_compatibility:
165165
print(f"NeRF compatibility mode enabled")
166166

@@ -183,8 +183,11 @@ def get_scene(scene):
183183
testbed.nerf.training.random_bg_color = False
184184

185185
# Ensure that the training mode is set to NeRF.
186+
if testbed.nerf.training.train_mode != ngp.TrainMode.Nerf:
187+
print(f"Warning: forcing train mode to NeRF for nerf compatibility (was {testbed.nerf.training.train_mode})")
186188
testbed.nerf.training.train_mode = ngp.TrainMode.Nerf
187189

190+
188191
old_training_step = 0
189192
n_steps = args.n_steps
190193

@@ -196,14 +199,15 @@ def get_scene(scene):
196199

197200
original_train_mode = ngp.TrainMode(testbed.nerf.training.train_mode)
198201
prev_train_mode = original_train_mode
202+
use_training_schedule = True
199203

200204
tqdm_last_update = 0
201205
if n_steps > 0:
202206
with tqdm(desc="Training", total=n_steps, unit="steps") as t:
203207
while testbed.frame():
204-
if prev_train_mode != testbed.nerf.training.train_mode and not args.no_rflrelax_training_schedule:
205-
print("Disabling RflRelax training schedule due to UI train mode change")
206-
args.no_rflrelax_training_schedule = True
208+
if prev_train_mode != testbed.nerf.training.train_mode and use_training_schedule:
209+
print("Disabling Rfl/RflRelax training schedule due to UI train mode change")
210+
use_training_schedule = False
207211

208212
if testbed.want_repl():
209213
repl(testbed)
@@ -221,13 +225,21 @@ def get_scene(scene):
221225
t.reset()
222226

223227
# Rfl / RflRelax training schedule
224-
progress_fraction = float(testbed.training_step) / n_steps
225-
if original_train_mode == ngp.TrainMode.RflRelax and not args.no_rflrelax_training_schedule:
226-
# By default only enable RflRelax mode between 15k and 30k steps
227-
if 3/7 <= progress_fraction < 6/7:
228-
testbed.nerf.training.train_mode = ngp.TrainMode.RflRelax
229-
else:
230-
testbed.nerf.training.train_mode = ngp.TrainMode.Nerf
228+
if use_training_schedule:
229+
if original_train_mode == ngp.TrainMode.RflRelax:
230+
# By default only enable RflRelax mode in the middle of training. Start with NeRF mode,
231+
# then switch to RflRelax mode to "surface-ify" the scene, then switch back to NeRF mode
232+
# at the very end for fine-tuning.
233+
if args.rflrelax_begin_step <= testbed.training_step < args.rflrelax_end_step:
234+
testbed.nerf.training.train_mode = ngp.TrainMode.RflRelax
235+
else:
236+
testbed.nerf.training.train_mode = ngp.TrainMode.Nerf
237+
elif original_train_mode == ngp.TrainMode.Rfl:
238+
# Start in NeRF mode, then switch to RFL mode after a warmup period
239+
if testbed.training_step > args.rfl_warmup_steps:
240+
testbed.nerf.training.train_mode = ngp.TrainMode.Rfl
241+
else:
242+
testbed.nerf.training.train_mode = ngp.TrainMode.Nerf
231243

232244
now = time.monotonic()
233245
if now - tqdm_last_update > 0.1:

src/nerf_loader.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,8 @@ NerfDataset load_nerf(const std::vector<fs::path>& jsonpaths, float sharpen_amou
442442
}
443443

444444
if (json.contains("from_mitsuba")) {
445-
result.from_mitsuba = bool(json["from_mitsuba"]);
446-
}
445+
result.from_mitsuba = bool(json["from_mitsuba"]);
446+
}
447447

448448
if (json.contains("fix_premult")) {
449449
fix_premult = (bool)json["fix_premult"];

src/python_api.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,6 @@ PYBIND11_MODULE(pyngp, m) {
800800
.def_readwrite("near_distance", &Testbed::Nerf::Training::near_distance)
801801
.def_readwrite("density_grid_decay", &Testbed::Nerf::Training::density_grid_decay)
802802
.def_readwrite("train_mode", &Testbed::Nerf::Training::train_mode)
803-
.def_readwrite("rfl_warmup_steps", &Testbed::Nerf::Training::rfl_warmup_steps)
804803
.def_readwrite("extrinsic_l2_reg", &Testbed::Nerf::Training::extrinsic_l2_reg)
805804
.def_readwrite("extrinsic_learning_rate", &Testbed::Nerf::Training::extrinsic_learning_rate)
806805
.def_readwrite("intrinsic_l2_reg", &Testbed::Nerf::Training::intrinsic_l2_reg)

0 commit comments

Comments
 (0)