6
6
#include "access/ivfflat.h"
7
7
#include "miscadmin.h"
8
8
9
+ static int
10
+ GtypeCompareVectors (const void * a , const void * b );
11
+
9
12
/*
10
13
* Initialize with kmeans++
11
14
*
@@ -29,7 +32,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
29
32
collation = index -> rd_indcollation [0 ];
30
33
31
34
// Choose an initial center uniformly at random
32
- VectorArraySet (centers , 0 , VectorArrayGet ( samples , RandomInt () % samples -> length ) );
35
+ VectorArraySet (centers , 0 , & samples -> items [ RandomInt () % samples -> length ] );
33
36
centers -> length ++ ;
34
37
35
38
for (j = 0 ; j < numSamples ; j ++ )
@@ -41,11 +44,11 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
41
44
sum = 0.0 ;
42
45
43
46
for (j = 0 ; j < numSamples ; j ++ ) {
44
- vec = VectorArrayGet ( samples , j ) ;
45
-
47
+ vec = & samples -> items [ j ] ;
48
+
46
49
// Only need to compute distance for new center
47
50
// TODO Use triangle inequality to reduce distance calculations
48
- distance = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet ( centers , i ) )));
51
+ distance = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& centers -> items [ i ] )));
49
52
50
53
// Set lower bound
51
54
lowerBound [j * numCenters + i ] = distance ;
@@ -71,7 +74,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
71
74
break ;
72
75
}
73
76
74
- VectorArraySet (centers , i + 1 , VectorArrayGet ( samples , j ) );
77
+ VectorArraySet (centers , i + 1 , & samples -> items [ j ] );
75
78
centers -> length ++ ;
76
79
}
77
80
@@ -96,7 +99,7 @@ ApplyNorm(FmgrInfo *normprocinfo, Oid collation, gtype * vec) {
96
99
* Compare vectors
97
100
*/
98
101
static int
99
- CompareVectors (const void * a , const void * b ) {
102
+ GtypeCompareVectors (const void * a , const void * b ) {
100
103
return gtype_vector_cmp ((Vector * ) a , (Vector * ) b );
101
104
}
102
105
@@ -112,11 +115,11 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) {
112
115
113
116
// Copy existing vectors while avoiding duplicates
114
117
if (samples -> length > 0 ) {
115
- qsort (samples -> items , samples -> length , VECTOR_SIZE (samples -> dim ), CompareVectors );
118
+ qsort (samples -> items , samples -> length , VECTOR_SIZE (samples -> dim ), GtypeCompareVectors );
116
119
for (int i = 0 ; i < samples -> length ; i ++ ) {
117
- vec = VectorArrayGet (samples , i );
120
+ vec = & samples -> items [ i ]; //GTypeVectorArrayGet (samples, i);
118
121
119
- if (i == 0 || CompareVectors (vec , VectorArrayGet ( samples , i - 1 ) ) != 0 ) {
122
+ if (i == 0 || GtypeCompareVectors (vec , & samples -> items [ i - 1 ] ) != 0 ) {
120
123
VectorArraySet (centers , centers -> length , vec );
121
124
centers -> length ++ ;
122
125
}
@@ -125,8 +128,9 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) {
125
128
126
129
// Fill remaining with random data
127
130
while (centers -> length < centers -> maxlen ) {
128
- vec = VectorArrayGet (centers , centers -> length );
129
-
131
+ //vec = GTypeVectorArrayGet(centers, centers->length);
132
+ vec = & (centers -> items [centers -> length ]);
133
+
130
134
SET_VARSIZE (vec , VECTOR_SIZE (dimensions ));
131
135
vec -> root .header = dimensions | GT_FEXTENDED_COMPOSITE ;
132
136
vec -> root .children [0 ] = GT_HEADER_VECTOR ;
@@ -221,9 +225,9 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
221
225
halfcdist = palloc_extended (halfcdistSize , MCXT_ALLOC_HUGE );
222
226
newcdist = palloc (newcdistSize );
223
227
224
- newCenters = VectorArrayInit (numCenters , dimensions );
228
+ newCenters = GtypeVectorArrayInit (numCenters , dimensions );
225
229
for (j = 0 ; j < numCenters ; j ++ ) {
226
- vec = VectorArrayGet ( newCenters , j ) ;
230
+ vec = & newCenters -> items [ j ] ;
227
231
SET_VARSIZE (vec , VECTOR_SIZE (dimensions ));
228
232
vec -> root .header = dimensions | GT_FEXTENDED_COMPOSITE ;
229
233
vec -> root .children [0 ] = GT_HEADER_VECTOR ;
@@ -263,11 +267,10 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
263
267
// Step 1: For all centers, compute distance
264
268
for (j = 0 ; j < numCenters ; j ++ )
265
269
{
266
- vec = VectorArrayGet (centers , j );
267
-
270
+ vec = & (centers -> items [j ]);
268
271
for (k = j + 1 ; k < numCenters ; k ++ )
269
272
{
270
- distance = 0.5 * DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet (centers , k ))));
273
+ distance = 0.5 * DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& (centers -> items [ k ] ))));
271
274
halfcdist [j * numCenters + k ] = distance ;
272
275
halfcdist [k * numCenters + j ] = distance ;
273
276
}
@@ -313,12 +316,12 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
313
316
if (upperBound [j ] <= halfcdist [closestCenters [j ] * numCenters + k ])
314
317
continue ;
315
318
316
- vec = VectorArrayGet ( samples , j ) ;
317
-
319
+ vec = & samples -> items [ j ] ;
320
+
318
321
// Step 3a
319
322
if (rj )
320
323
{
321
- dxcx = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet (centers , closestCenters [j ]))));
324
+ dxcx = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& (centers -> items [ closestCenters [j ] ]))));
322
325
323
326
// d(x,c(x)) computed, which is a form of d(x,c)
324
327
lowerBound [j * numCenters + closestCenters [j ]] = dxcx ;
@@ -332,7 +335,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
332
335
// Step 3b
333
336
if (dxcx > lowerBound [j * numCenters + k ] || dxcx > halfcdist [closestCenters [j ] * numCenters + k ])
334
337
{
335
- dxc = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (VectorArrayGet (centers , k ))));
338
+ dxc = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (vec ), PointerGetDatum (& (centers -> items [ k ] ))));
336
339
337
340
// d(x,c) calculated
338
341
lowerBound [j * numCenters + k ] = dxc ;
@@ -354,7 +357,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
354
357
// Step 4: For each center c, let m(c) be mean of all points assigned
355
358
for (j = 0 ; j < numCenters ; j ++ )
356
359
{
357
- vec = VectorArrayGet (newCenters , j );
360
+ vec = & (newCenters -> items [ j ] );
358
361
for (k = 0 ; k < dimensions ; k ++ )
359
362
* ((float8 * )& vec -> root .children [1 + (k * sizeof (float8 ))]) = 0.0 ;
360
363
@@ -363,11 +366,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
363
366
364
367
for (j = 0 ; j < numSamples ; j ++ )
365
368
{
366
- vec = VectorArrayGet ( samples , j );
369
+ vec = & samples -> items [ j ];
367
370
closestCenter = closestCenters [j ];
368
371
369
372
// Increment sum and count of closest center
370
- newCenter = VectorArrayGet (newCenters , closestCenter );
373
+ newCenter = GTypeVectorArrayGet (newCenters , closestCenter );
371
374
for (k = 0 ; k < dimensions ; k ++ )
372
375
* ((float8 * )& newCenter -> root .children [1 + (k * sizeof (float8 ))]) += * ((float8 * )(& vec -> root .children [1 + (k * sizeof (float8 ))]));
373
376
@@ -376,7 +379,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
376
379
377
380
for (j = 0 ; j < numCenters ; j ++ )
378
381
{
379
- vec = VectorArrayGet (newCenters , j );
382
+ vec = GTypeVectorArrayGet (newCenters , j );
380
383
381
384
if (centerCounts [j ] > 0 )
382
385
{
@@ -405,7 +408,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
405
408
406
409
// Step 5
407
410
for (j = 0 ; j < numCenters ; j ++ )
408
- newcdist [j ] = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (VectorArrayGet ( centers , j )) , PointerGetDatum (VectorArrayGet ( newCenters , j ) )));
411
+ newcdist [j ] = DatumGetFloat8 (FunctionCall2Coll (procinfo , collation , PointerGetDatum (& centers -> items [ j ]) , PointerGetDatum (& newCenters -> items [ j ] )));
409
412
410
413
for (j = 0 ; j < numSamples ; j ++ )
411
414
{
@@ -427,7 +430,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
427
430
428
431
// Step 7
429
432
for (j = 0 ; j < numCenters ; j ++ )
430
- memcpy (VectorArrayGet (centers , j ), VectorArrayGet (newCenters , j ), VECTOR_SIZE (dimensions ));
433
+ memcpy (& (centers -> items [ j ] ), & (newCenters -> items [ j ] ), VECTOR_SIZE (dimensions ));
431
434
432
435
if (changes == 0 && iteration != 0 )
433
436
break ;
@@ -460,7 +463,8 @@ CheckCenters(Relation index, VectorArray centers)
460
463
// Ensure no NaN or infinite values
461
464
for (int i = 0 ; i < centers -> length ; i ++ )
462
465
{
463
- vec = VectorArrayGet (centers , i );
466
+ //vec = GTypeVectorArrayGet(centers, i);
467
+ vec = & (centers -> items [i ]);
464
468
465
469
for (int j = 0 ; j < AGT_ROOT_COUNT (vec ); j ++ ) {
466
470
if (isnan ((double ) vec -> root .children [1 + (j * sizeof (float8 ))]))
@@ -473,10 +477,11 @@ CheckCenters(Relation index, VectorArray centers)
473
477
474
478
// Ensure no duplicate centers
475
479
// Fine to sort in-place
476
- qsort (centers -> items , centers -> length , VECTOR_SIZE (centers -> dim ), CompareVectors );
480
+ qsort (centers -> items , centers -> length , VECTOR_SIZE (centers -> dim ), GtypeCompareVectors );
477
481
for (int i = 1 ; i < centers -> length ; i ++ )
478
482
{
479
- if (CompareVectors (VectorArrayGet (centers , i ), VectorArrayGet (centers , i - 1 )) == 0 )
483
+ //if (GtypeCompareVectors(GTypeVectorArrayGet(centers, i), GTypeVectorArrayGet(centers, i - 1)) == 0)
484
+ if (GtypeCompareVectors (& (centers -> items [i ]), & (centers -> items [i - 1 ])) == 0 )
480
485
elog (ERROR , "Duplicate centers detected. Please report a bug." );
481
486
}
482
487
@@ -489,7 +494,7 @@ CheckCenters(Relation index, VectorArray centers)
489
494
490
495
for (int i = 0 ; i < centers -> length ; i ++ )
491
496
{
492
- norm = DatumGetFloat8 (FunctionCall1Coll (normprocinfo , collation , PointerGetDatum (VectorArrayGet (centers , i ))));
497
+ norm = DatumGetFloat8 (FunctionCall1Coll (normprocinfo , collation , PointerGetDatum (& (centers -> items [ i ] ))));
493
498
if (norm == 0 )
494
499
elog (ERROR , "Zero norm detected. Please report a bug." );
495
500
}
0 commit comments