@@ -619,22 +619,22 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean},
619619 R = inplace ? mul! (r, a' , b) : a' b
620620 sa2 = sum (abs2, a, dims= 1 )
621621 sb2 = sum (abs2, b, dims= 1 )
622- threshT = convert ( eltype (r), dist . thresh )
623- @inbounds if threshT <= 0
622+ z² = zero ( real ( eltype (R)) )
623+ @inbounds if dist . thresh <= 0
624624 # If there's no chance of triggering the threshold, we can use @simd
625625 for j = 1 : nb
626626 sb = sb2[j]
627627 @simd for i = 1 : na
628- r[i, j] = eval_end (dist, (max (sa2[i] + sb - 2 real (R[i, j]), 0 )))
628+ r[i, j] = eval_end (dist, (max (sa2[i] + sb - 2 real (R[i, j]), z² )))
629629 end
630630 end
631631 else
632632 for j = 1 : nb
633633 sb = sb2[j]
634634 for i = 1 : na
635635 selfterms = sa2[i] + sb
636- v = max (selfterms - 2 real (R[i, j]), 0 )
637- if v < threshT * selfterms
636+ v = max (selfterms - 2 real (R[i, j]), z² )
637+ if v < dist . thresh * selfterms
638638 # The distance is likely to be inaccurate, recalculate directly
639639 # This reflects the following:
640640 # while sqrt(x+ϵ) ≈ sqrt(x) + O(ϵ/sqrt(x)) when |x| >> ϵ,
@@ -658,22 +658,23 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean}, a::Ab
658658 # the following checks if a'*b can be stored in r directly, it fails for complex eltypes
659659 R = inplace ? mul! (r, a' , a) : a' a
660660 sa2 = sum (abs2, a, dims= 1 )
661- threshT = convert (eltype (r), dist. thresh)
661+ safe = dist. thresh <= 0
662+ z² = zero (real (eltype (R)))
662663 @inbounds for j = 1 : n
663664 for i = 1 : (j - 1 )
664665 r[i, j] = r[j, i]
665666 end
666- r[j, j] = 0
667+ r[j, j] = zero ( eltype (r))
667668 sa2j = sa2[j]
668- if threshT <= 0
669+ if safe
669670 @simd for i = (j + 1 ): n
670- r[i, j] = eval_end (dist, (max (sa2[i] + sa2j - 2 real (R[i, j]), 0 )))
671+ r[i, j] = eval_end (dist, (max (sa2[i] + sa2j - 2 real (R[i, j]), z² )))
671672 end
672673 else
673674 for i = (j + 1 ): n
674675 selfterms = sa2[i] + sa2j
675- v = max (selfterms - 2 real (R[i, j]), 0 )
676- if v < threshT * selfterms
676+ v = max (selfterms - 2 real (R[i, j]), z² )
677+ if v < dist . thresh * selfterms
677678 v = zero (v)
678679 for k = 1 : m
679680 v += abs2 (a[k, i] - a[k, j])
@@ -698,9 +699,10 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE
698699 # the following checks if a'*b can be stored in r directly, it fails for complex eltypes
699700 inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (b)))) === eltype (r)
700701 R = inplace ? mul! (r, a' , w .* b) : a' * Diagonal (w)* b
702+ z² = zero (real (eltype (R)))
701703 for j = 1 : nb
702704 @simd for i = 1 : na
703- @inbounds r[i, j] = eval_end (dist, max (sa2[i] + sb2[j] - 2 real (R[i, j]), 0 ))
705+ @inbounds r[i, j] = eval_end (dist, max (sa2[i] + sb2[j] - 2 real (R[i, j]), z² ))
704706 end
705707 end
706708 r
@@ -715,14 +717,15 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE
715717 # the following checks if a'*b can be stored in r directly, it fails for complex eltypes
716718 inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (a)))) === eltype (r)
717719 R = inplace ? mul! (r, a' , w .* a) : a' * Diagonal (w)* a
720+ z² = zero (real (eltype (R)))
718721
719722 @inbounds for j = 1 : n
720723 for i = 1 : (j - 1 )
721724 r[i, j] = r[j, i]
722725 end
723- r[j, j] = 0
726+ r[j, j] = zero ( eltype (r))
724727 @simd for i = (j + 1 ): n
725- r[i, j] = eval_end (dist, max (sa2[i] + sa2[j] - 2 real (R[i, j]), 0 ))
728+ r[i, j] = eval_end (dist, max (sa2[i] + sa2[j] - 2 real (R[i, j]), z² ))
726729 end
727730 end
728731 r
@@ -734,28 +737,30 @@ function _pairwise!(r::AbstractMatrix, ::CosineDist,
734737 a:: AbstractMatrix , b:: AbstractMatrix )
735738 require_one_based_indexing (r, a, b)
736739 m, na, nb = get_pairwise_dims (r, a, b)
737- mul! (r, a' , b)
740+ inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (b)))) === eltype (r)
741+ R = inplace ? mul! (r, a' , b) : a' b
738742 ra = norm_percol (a)
739743 rb = norm_percol (b)
740744 for j = 1 : nb
741745 @simd for i = 1 : na
742- @inbounds r[i, j] = max (1 - r [i, j] / (ra[i] * rb[j]), 0 )
746+ @inbounds r[i, j] = max (1 - R [i, j] / (ra[i] * rb[j]), 0 )
743747 end
744748 end
745749 r
746750end
747751function _pairwise! (r:: AbstractMatrix , :: CosineDist , a:: AbstractMatrix )
748752 require_one_based_indexing (r, a)
749753 m, n = get_pairwise_dims (r, a)
750- mul! (r, a' , a)
754+ inplace = promote_type (eltype (r), typeof (oneunit (eltype (a))' oneunit (eltype (a)))) === eltype (r)
755+ R = inplace ? mul! (r, a' , a) : a' a
751756 ra = norm_percol (a)
752757 @inbounds for j = 1 : n
753758 for i = 1 : (j - 1 )
754759 r[i, j] = r[j, i]
755760 end
756- r[j, j] = 0
761+ r[j, j] = zero ( eltype (r))
757762 @simd for i = j + 1 : n
758- r[i, j] = max (1 - r [i, j] / (ra[i] * ra[j]), 0 )
763+ r[i, j] = max (1 - R [i, j] / (ra[i] * ra[j]), 0 )
759764 end
760765 end
761766 r
0 commit comments