@@ -13,6 +13,7 @@ use crate::dist::*;
1313use crate :: tiling:: PlaneRegion ;
1414use crate :: util:: Pixel ;
1515use crate :: util:: PixelType ;
16+ use std:: arch:: x86_64:: * ;
1617
1718type CdefDistKernelFn = unsafe extern fn (
1819 src : * const u8 ,
@@ -22,6 +23,13 @@ type CdefDistKernelFn = unsafe extern fn(
2223 ret_ptr : * mut u32 ,
2324) ;
2425
/// Signature shared by the high-bit-depth (16-bit sample) CDEF distortion
/// kernels.
///
/// A kernel takes a source and a destination pointer, each paired with its
/// row stride, and returns three `u32` statistics. The dispatcher
/// destructures that tuple as `(svar, dvar, sse)`.
type CdefDistKernelHBDFn = unsafe fn(
  src: *const u16,
  src_stride: isize,
  dst: *const u16,
  dst_stride: isize,
) -> (u32, u32, u32);
32+
2533extern {
2634 fn rav1e_cdef_dist_kernel_4x4_sse2 (
2735 src : * const u8 , src_stride : isize , dst : * const u8 , dst_stride : isize ,
@@ -63,12 +71,12 @@ pub fn cdef_dist_kernel<T: Pixel>(
6371 #[ cfg( feature = "check_asm" ) ]
6472 let ref_dist = call_rust ( ) ;
6573
66- let mut ret_buf = [ 0u32 ; 3 ] ;
67- match T :: type_enum ( ) {
74+ let ( svar, dvar, sse) = match T :: type_enum ( ) {
6875 PixelType :: U8 => {
6976 if let Some ( func) =
7077 CDEF_DIST_KERNEL_FNS [ cpu. as_index ( ) ] [ kernel_fn_index ( w, h) ]
7178 {
79+ let mut ret_buf = [ 0u32 ; 3 ] ;
7280 // SAFETY: Calls Assembly code.
7381 unsafe {
7482 func (
@@ -79,16 +87,30 @@ pub fn cdef_dist_kernel<T: Pixel>(
7987 ret_buf. as_mut_ptr ( ) ,
8088 )
8189 }
90+
91+ ( ret_buf[ 0 ] , ret_buf[ 1 ] , ret_buf[ 2 ] )
8292 } else {
8393 return call_rust ( ) ;
8494 }
8595 }
86- PixelType :: U16 => return call_rust ( ) ,
87- }
88-
89- let svar = ret_buf[ 0 ] ;
90- let dvar = ret_buf[ 1 ] ;
91- let sse = ret_buf[ 2 ] ;
96+ PixelType :: U16 => {
97+ if let Some ( func) =
98+ CDEF_DIST_KERNEL_HBD_FNS [ cpu. as_index ( ) ] [ kernel_fn_index ( w, h) ]
99+ {
100+ // SAFETY: Calls Assembly code.
101+ unsafe {
102+ func (
103+ src. data_ptr ( ) as * const _ ,
104+ T :: to_asm_stride ( src. plane_cfg . stride ) ,
105+ dst. data_ptr ( ) as * const _ ,
106+ T :: to_asm_stride ( dst. plane_cfg . stride ) ,
107+ )
108+ }
109+ } else {
110+ return call_rust ( ) ;
111+ }
112+ }
113+ } ;
92114
93115 let dist = apply_ssim_boost ( sse, svar, dvar, bit_depth) ;
94116 #[ cfg( feature = "check_asm" ) ]
@@ -128,6 +150,98 @@ cpu_function_lookup_table!(
128150 [ SSE2 ]
129151) ;
130152
/// Horizontally adds the eight 32-bit lanes of `ymm` and returns the total.
///
/// Lanes are combined with `_mm_add_epi32`, so the result wraps on 32-bit
/// overflow.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn mm256_sum_i32(ymm: __m256i) -> i32 {
  // 8 lanes -> 4: add the upper 128-bit half onto the lower half.
  let upper = _mm256_extracti128_si256(ymm, 1);
  let lower = _mm256_castsi256_si128(ymm);
  let quad = _mm_add_epi32(lower, upper);
  // 4 lanes -> 2: bring lanes 2 and 3 down onto lanes 0 and 1.
  let pair = _mm_add_epi32(quad, _mm_shuffle_epi32(quad, 0b11_10_11_10));
  // 2 lanes -> 1: broadcast lane 1 so that lane 0 ends up with the total.
  let total = _mm_add_epi32(pair, _mm_shuffle_epi32(pair, 0b01_01_01_01));
  _mm_cvtsi128_si32(total)
}
166+
167+ #[ target_feature( enable = "avx2" ) ]
168+ #[ inline]
169+ unsafe fn rav1e_cdef_dist_kernel_8x8_hbd_avx2 (
170+ src : * const u16 , src_stride : isize , dst : * const u16 , dst_stride : isize ,
171+ ) -> ( u32 , u32 , u32 ) {
172+ let src = src as * const u8 ;
173+ let dst = dst as * const u8 ;
174+
175+ unsafe fn sum16 ( src : * const u8 , src_stride : isize ) -> u32 {
176+ let h = 8 ;
177+ let res = ( 0 ..h)
178+ . map ( |row| _mm_load_si128 ( src. offset ( row * src_stride) as * const _ ) )
179+ . reduce ( |a, b| _mm_add_epi16 ( a, b) )
180+ . unwrap ( ) ;
181+
182+ let m32 = _mm256_cvtepi16_epi32 ( res) ;
183+ mm256_sum_i32 ( m32) as u32
184+ }
185+ unsafe fn mpadd32 (
186+ src : * const u8 , src_stride : isize , dst : * const u8 , dst_stride : isize ,
187+ ) -> u32 {
188+ let h = 8 ;
189+ let res = ( 0 ..h / 2 )
190+ . map ( |row| {
191+ let s1 = _mm_load_si128 ( src. offset ( 2 * row * src_stride) as * const _ ) ;
192+ let s2 =
193+ _mm_load_si128 ( src. offset ( ( 2 * row + 1 ) * src_stride) as * const _ ) ;
194+ let s = _mm256_inserti128_si256 ( _mm256_castsi128_si256 ( s1) , s2, 1 ) ;
195+
196+ let d1 = _mm_load_si128 ( dst. offset ( 2 * row * dst_stride) as * const _ ) ;
197+ let d2 =
198+ _mm_load_si128 ( dst. offset ( ( 2 * row + 1 ) * dst_stride) as * const _ ) ;
199+ let d = _mm256_inserti128_si256 ( _mm256_castsi128_si256 ( d1) , d2, 1 ) ;
200+
201+ _mm256_madd_epi16 ( s, d)
202+ } )
203+ . reduce ( |a, b| _mm256_add_epi32 ( a, b) )
204+ . unwrap ( ) ;
205+ mm256_sum_i32 ( res) as u32
206+ }
207+
208+ let sum_s = sum16 ( src, src_stride) ;
209+ let sum_d = sum16 ( dst, dst_stride) ;
210+ let sum_s2 = mpadd32 ( src, src_stride, src, src_stride) ;
211+ let sum_d2 = mpadd32 ( dst, dst_stride, dst, dst_stride) ;
212+ let sum_sd = mpadd32 ( src, src_stride, dst, dst_stride) ;
213+
214+ // To get the distortion, compute sum of squared error and apply a weight
215+ // based on the variance of the two planes.
216+ let sse = sum_d2 + sum_s2 - 2 * sum_sd;
217+
218+ // Convert to 64-bits to avoid overflow when squaring
219+ let sum_s = sum_s as u64 ;
220+ let sum_d = sum_d as u64 ;
221+
222+ let svar = ( sum_s2 as u64 - ( sum_s * sum_s) / 64 ) as u32 ;
223+ let dvar = ( sum_d2 as u64 - ( sum_d * sum_d) / 64 ) as u32 ;
224+
225+ ( svar, dvar, sse)
226+ }
227+
228+ static CDEF_DIST_KERNEL_HBD_FNS_AVX2 : [ Option < CdefDistKernelHBDFn > ;
229+ CDEF_DIST_KERNEL_FNS_LENGTH ] = {
230+ let mut out: [ Option < CdefDistKernelHBDFn > ; CDEF_DIST_KERNEL_FNS_LENGTH ] =
231+ [ None ; CDEF_DIST_KERNEL_FNS_LENGTH ] ;
232+
233+ out[ kernel_fn_index ( 8 , 8 ) ] = Some ( rav1e_cdef_dist_kernel_8x8_hbd_avx2) ;
234+
235+ out
236+ } ;
237+
238+ cpu_function_lookup_table ! (
239+ CDEF_DIST_KERNEL_HBD_FNS :
240+ [ [ Option <CdefDistKernelHBDFn >; CDEF_DIST_KERNEL_FNS_LENGTH ] ] ,
241+ default : [ None ; CDEF_DIST_KERNEL_FNS_LENGTH ] ,
242+ [ AVX2 ]
243+ ) ;
244+
131245#[ cfg( test) ]
132246pub mod test {
133247 use super :: * ;
@@ -204,16 +318,34 @@ pub mod test {
204318 cdef_diff_tester ( 8 , random_planes :: < u8 > ) ;
205319 }
206320
321+ #[ test]
322+ fn cdef_dist_simd_random_hbd ( ) {
323+ cdef_diff_tester ( 10 , random_planes :: < u16 > ) ;
324+ cdef_diff_tester ( 12 , random_planes :: < u16 > ) ;
325+ }
326+
207327 #[ test]
208328 fn cdef_dist_simd_large ( ) {
209329 cdef_diff_tester ( 8 , max_planes :: < u8 > ) ;
210330 }
211331
332+ #[ test]
333+ fn cdef_dist_simd_large_hbd ( ) {
334+ cdef_diff_tester ( 10 , max_planes :: < u16 > ) ;
335+ cdef_diff_tester ( 12 , max_planes :: < u16 > ) ;
336+ }
337+
212338 #[ test]
213339 fn cdef_dist_simd_large_diff ( ) {
214340 cdef_diff_tester ( 8 , max_diff_planes :: < u8 > ) ;
215341 }
216342
343+ #[ test]
344+ fn cdef_dist_simd_large_diff_hbd ( ) {
345+ cdef_diff_tester ( 10 , max_diff_planes :: < u16 > ) ;
346+ cdef_diff_tester ( 12 , max_diff_planes :: < u16 > ) ;
347+ }
348+
217349 fn cdef_diff_tester < T : Pixel > (
218350 bd : usize , gen_planes : fn ( bd : usize ) -> ( Plane < T > , Plane < T > ) ,
219351 ) {
0 commit comments