@@ -29,9 +29,9 @@ struct astcenc_rdo_context
2929 std::vector<image_block> m_blocks;
3030
3131 uint32_t m_total_blocks = 0 ;
32- uint32_t m_xblocks = 0 ;
33- uint32_t m_yblocks = 0 ;
34- uint32_t m_zblocks = 0 ;
32+ uint32_t m_image_x = 0 ;
33+ uint32_t m_image_y = 0 ;
34+ uint32_t m_image_z = 0 ;
3535};
3636
3737#define ASTCENC_RDO_SPECIALIZE_DIFF 1
@@ -102,9 +102,9 @@ static uint32_t init_rdo_context(
102102
103103 rdo_ctx.m_blocks .resize (total_blocks);
104104 rdo_ctx.m_total_blocks = total_blocks;
105- rdo_ctx.m_xblocks = xblocks ;
106- rdo_ctx.m_yblocks = yblocks ;
107- rdo_ctx.m_zblocks = zblocks ;
105+ rdo_ctx.m_image_x = image. dim_x ;
106+ rdo_ctx.m_image_y = image. dim_y ;
107+ rdo_ctx.m_image_z = image. dim_z ;
108108
109109 vfloat4 channel_weight = vfloat4 (ctx.config .cw_r_weight ,
110110 ctx.config .cw_g_weight ,
@@ -282,14 +282,22 @@ static float compute_symbolic_block_difference_constant(
282282static float compute_block_mse (
283283 const image_block& orig,
284284 const image_block& cmp,
285+ const block_size_descriptor& bsd,
286+ uint32_t image_x,
287+ uint32_t image_y,
288+ uint32_t image_z,
285289 float orig_scale,
286290 float cmp_scale
287291) {
288292 vfloatacc summav = vfloatacc::zero ();
289293 vint lane_id = vint::lane_id ();
290294 uint32_t texel_count = orig.texel_count ;
291295
292- for (uint32_t i = 0 ; i < texel_count; i += ASTCENC_SIMD_WIDTH)
296+ uint32_t block_x = astc::min (image_x - orig.xpos , (uint32_t )bsd.xdim );
297+ uint32_t block_y = astc::min (image_y - orig.ypos , (uint32_t )bsd.ydim );
298+ uint32_t block_z = astc::min (image_z - orig.zpos , (uint32_t )bsd.zdim );
299+
300+ for (uint32_t i = 0 ; i < texel_count; i += ASTCENC_SIMD_WIDTH, lane_id += vint (ASTCENC_SIMD_WIDTH))
293301 {
294302 vfloat color_orig_r = loada (orig.data_r + i) * orig_scale;
295303 vfloat color_orig_g = loada (orig.data_g + i) * orig_scale;
@@ -318,12 +326,15 @@ static float compute_block_mse(
318326 + color_error_a * orig.channel_weight .lane <3 >();
319327
320328 // Mask off bad lanes
321- vmask mask = lane_id < vint (texel_count);
322- lane_id += vint (ASTCENC_SIMD_WIDTH);
329+ vint lane_id_z (float_to_int (int_to_float (lane_id) / float (bsd.xdim * bsd.ydim )));
330+ vint rem_idx = lane_id - lane_id_z * vint (bsd.xdim * bsd.ydim );
331+ vint lane_id_y = float_to_int (int_to_float (rem_idx) / bsd.xdim );
332+ vint lane_id_x = rem_idx - lane_id_y * vint (bsd.xdim );
333+ vmask mask = (lane_id_x < vint (block_x)) & (lane_id_y < vint (block_y)) & (lane_id_z < vint (block_z));
323334 haccumulate (summav, metric, mask);
324335 }
325336
326- return hadd_s (summav) / texel_count ;
337+ return hadd_s (summav) / (block_x * block_y * block_z) ;
327338}
328339#endif
329340
@@ -376,22 +387,15 @@ static float compute_block_difference(
376387 // ERT expects texel values to be in [0, 255]
377388 return squared_error / blk.texel_count * sqr (255 .0f / 65535 .0f );
378389#else
379- uint32_t block_z = block_idx / (rdo_ctx.m_xblocks * rdo_ctx.m_yblocks );
380- uint32_t slice_idx = block_idx - block_z * rdo_ctx.m_xblocks * rdo_ctx.m_yblocks ;
381- uint32_t block_y = slice_idx / rdo_ctx.m_xblocks ;
382- uint32_t block_x = slice_idx - block_y * rdo_ctx.m_xblocks ;
383-
384390 image_block decoded_blk;
385391 decoded_blk.decode_unorm8 = blk.decode_unorm8 ;
386392 decoded_blk.texel_count = blk.texel_count ;
387393 decoded_blk.channel_weight = blk.channel_weight ;
388394
389- decompress_symbolic_block (ctx.config .profile , *ctx.bsd ,
390- block_x * ctx.bsd ->xdim , block_y * ctx.bsd ->ydim , block_z * ctx.bsd ->zdim ,
391- scb, decoded_blk);
395+ decompress_symbolic_block (ctx.config .profile , *ctx.bsd , blk.xpos , blk.ypos , blk.zpos , scb, decoded_blk);
392396
393397 // ERT expects texel values to be in [0, 255]
394- return compute_block_mse (blk, decoded_blk, 255 .0f / 65535 .0f , 255 .0f );
398+ return compute_block_mse (blk, decoded_blk, *ctx. bsd , rdo_ctx. m_image_x , rdo_ctx. m_image_y , rdo_ctx. m_image_z , 255 .0f / 65535 .0f , 255 .0f );
395399#endif
396400}
397401
0 commit comments