Skip to content

Commit 74cd488

Browse files
committed
Try to get IVFFLAT indexing working again
1 parent b819026 commit 74cd488

File tree

9 files changed

+76
-56
lines changed

9 files changed

+76
-56
lines changed

regress/sql/vector.sql

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,20 @@ SELECT gtype_build_list('i'::text, tovector('"[0, 0, 0]"'::gtype));
188188

189189
SELECT gtype_build_map('i'::text, tovector('"[0, 0, 0]"'::gtype))->'"i"';
190190

191+
SELECT create_vlabel('vector', 'vlabel2');
192+
193+
--CREATE (v:vlabel2 {"i": tovector('[0, 0, 0]')})
194+
--RETURN v, v->'i' <-> tovector('[1, 1, 1]'), v->'i' <=> tovector('[1, 2, 3]'), v->'i' <#> tovector('[1, 1, 1]');
195+
196+
--CREATE (v:vlabel2 {"i": tovector('[1, 0, 0]')})
197+
--RETURN v, v.i <-> tovector('[1, 1, 1]'), v.i <=> tovector('[1, 2, 3]'), v.i <#> tovector('[1, 1, 1]');
198+
199+
200+
SELECT create_ivfflat_ip_ops_index('vector', 'vlabel2', 'i', 3, 100);
201+
202+
--CREATE INDEX vec_idx ON vector.vlabel
203+
191204
--
192-
-- cleanup
205+
-- clean up
193206
--
194207
DROP GRAPH vector CASCADE;

src/backend/access/ivfbuild.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState
185185
// Find the list that minimizes the distance
186186
for (int i = 0; i < centers->length; i++)
187187
{
188-
distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, value, PointerGetDatum(VectorArrayGet(centers, i))));
188+
distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, value, PointerGetDatum(&centers->items[i])));
189189

190190
if (distance < minDistance)
191191
{
@@ -388,7 +388,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In
388388
buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc);
389389
#endif
390390

391-
buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions);
391+
buildstate->centers = GtypeVectorArrayInit(buildstate->lists, buildstate->dimensions);
392392
buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists);
393393

394394
// Reuse for each tuple
@@ -445,7 +445,7 @@ ComputeCenters(IvfflatBuildState * buildstate)
445445

446446
// Sample rows
447447
// TODO Ensure within maintenance_work_mem
448-
buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions);
448+
buildstate->samples = GtypeVectorArrayInit(numSamples, buildstate->dimensions);
449449
if (buildstate->heap != NULL) {
450450
SampleRows(buildstate);
451451

@@ -510,7 +510,7 @@ CreateListPages(Relation index, VectorArray centers, int dimensions, int lists,
510510
// Load list
511511
list->startPage = InvalidBlockNumber;
512512
list->insertPage = InvalidBlockNumber;
513-
memcpy(&list->center, VectorArrayGet(centers, i), VECTOR_SIZE(dimensions));
513+
memcpy(&list->center, &centers->items[i], VECTOR_SIZE(dimensions));
514514

515515
// Ensure free space
516516
if (PageGetFreeSpace(page) < itemsz)
@@ -560,7 +560,7 @@ PrintKmeansMetrics(IvfflatBuildState * buildstate) {
560560
if (j == i)
561561
continue;
562562

563-
distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, PointerGetDatum(VectorArrayGet(buildstate->centers, i)), PointerGetDatum(VectorArrayGet(buildstate->centers, j))));
563+
distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, PointerGetDatum(GTypeVectorArrayGet(buildstate->centers, i)), PointerGetDatum(GTypeVectorArrayGet(buildstate->centers, j))));
564564
distance = (buildstate->listSums[i] + buildstate->listSums[j]) / distance;
565565

566566
if (distance > max)

src/backend/access/ivfkmeans.c

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include "access/ivfflat.h"
77
#include "miscadmin.h"
88

9+
static int
10+
GtypeCompareVectors(const void *a, const void *b);
11+
912
/*
1013
* Initialize with kmeans++
1114
*
@@ -29,7 +32,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
2932
collation = index->rd_indcollation[0];
3033

3134
// Choose an initial center uniformly at random
32-
VectorArraySet(centers, 0, VectorArrayGet(samples, RandomInt() % samples->length));
35+
VectorArraySet(centers, 0, &samples->items[RandomInt() % samples->length]);
3336
centers->length++;
3437

3538
for (j = 0; j < numSamples; j++)
@@ -41,11 +44,11 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
4144
sum = 0.0;
4245

4346
for (j = 0; j < numSamples; j++) {
44-
vec = VectorArrayGet(samples, j);
45-
47+
vec = &samples->items[j];
48+
4649
// Only need to compute distance for new center
4750
// TODO Use triangle inequality to reduce distance calculations
48-
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i))));
51+
distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(&centers->items[i])));
4952

5053
// Set lower bound
5154
lowerBound[j * numCenters + i] = distance;
@@ -71,7 +74,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float8 *lo
7174
break;
7275
}
7376

74-
VectorArraySet(centers, i + 1, VectorArrayGet(samples, j));
77+
VectorArraySet(centers, i + 1, &samples->items[j]);
7578
centers->length++;
7679
}
7780

@@ -96,7 +99,7 @@ ApplyNorm(FmgrInfo *normprocinfo, Oid collation, gtype * vec) {
9699
* Compare vectors
97100
*/
98101
static int
99-
CompareVectors(const void *a, const void *b) {
102+
GtypeCompareVectors(const void *a, const void *b) {
100103
return gtype_vector_cmp((Vector *) a, (Vector *) b);
101104
}
102105

@@ -112,11 +115,11 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) {
112115

113116
// Copy existing vectors while avoiding duplicates
114117
if (samples->length > 0) {
115-
qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors);
118+
qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), GtypeCompareVectors);
116119
for (int i = 0; i < samples->length; i++) {
117-
vec = VectorArrayGet(samples, i);
120+
vec = &samples->items[i];//GTypeVectorArrayGet(samples, i);
118121

119-
if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0) {
122+
if (i == 0 || GtypeCompareVectors(vec, &samples->items[i -1]) != 0) {
120123
VectorArraySet(centers, centers->length, vec);
121124
centers->length++;
122125
}
@@ -125,8 +128,9 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) {
125128

126129
// Fill remaining with random data
127130
while (centers->length < centers->maxlen) {
128-
vec = VectorArrayGet(centers, centers->length);
129-
131+
//vec = GTypeVectorArrayGet(centers, centers->length);
132+
vec = &(centers->items[centers->length]);
133+
130134
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
131135
vec->root.header = dimensions | GT_FEXTENDED_COMPOSITE;
132136
vec->root.children[0] = GT_HEADER_VECTOR;
@@ -221,9 +225,9 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
221225
halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);
222226
newcdist = palloc(newcdistSize);
223227

224-
newCenters = VectorArrayInit(numCenters, dimensions);
228+
newCenters = GtypeVectorArrayInit(numCenters, dimensions);
225229
for (j = 0; j < numCenters; j++) {
226-
vec = VectorArrayGet(newCenters, j);
230+
vec = &newCenters->items[j];
227231
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
228232
vec->root.header = dimensions | GT_FEXTENDED_COMPOSITE;
229233
vec->root.children[0] = GT_HEADER_VECTOR;
@@ -263,11 +267,10 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
263267
// Step 1: For all centers, compute distance
264268
for (j = 0; j < numCenters; j++)
265269
{
266-
vec = VectorArrayGet(centers, j);
267-
270+
vec = &(centers->items[j]);
268271
for (k = j + 1; k < numCenters; k++)
269272
{
270-
distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
273+
distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(&(centers->items[k]))));
271274
halfcdist[j * numCenters + k] = distance;
272275
halfcdist[k * numCenters + j] = distance;
273276
}
@@ -313,12 +316,12 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
313316
if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k])
314317
continue;
315318

316-
vec = VectorArrayGet(samples, j);
317-
319+
vec = &samples->items[j];
320+
318321
// Step 3a
319322
if (rj)
320323
{
321-
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j]))));
324+
dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(&(centers->items[closestCenters[j]]))));
322325

323326
// d(x,c(x)) computed, which is a form of d(x,c)
324327
lowerBound[j * numCenters + closestCenters[j]] = dxcx;
@@ -332,7 +335,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
332335
// Step 3b
333336
if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k])
334337
{
335-
dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k))));
338+
dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(&(centers->items[k]))));
336339

337340
// d(x,c) calculated
338341
lowerBound[j * numCenters + k] = dxc;
@@ -354,7 +357,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
354357
// Step 4: For each center c, let m(c) be mean of all points assigned
355358
for (j = 0; j < numCenters; j++)
356359
{
357-
vec = VectorArrayGet(newCenters, j);
360+
vec =&(newCenters->items[j]);
358361
for (k = 0; k < dimensions; k++)
359362
*((float8 *)&vec->root.children[1 + (k * sizeof(float8))]) = 0.0;
360363

@@ -363,11 +366,11 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
363366

364367
for (j = 0; j < numSamples; j++)
365368
{
366-
vec = VectorArrayGet(samples, j);
369+
vec = &samples->items[j];
367370
closestCenter = closestCenters[j];
368371

369372
// Increment sum and count of closest center
370-
newCenter = VectorArrayGet(newCenters, closestCenter);
373+
newCenter = GTypeVectorArrayGet(newCenters, closestCenter);
371374
for (k = 0; k < dimensions; k++)
372375
*((float8 *)&newCenter->root.children[1 + (k * sizeof(float8))]) += *((float8 *)(&vec->root.children[1 + (k * sizeof(float8))]));
373376

@@ -376,7 +379,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
376379

377380
for (j = 0; j < numCenters; j++)
378381
{
379-
vec = VectorArrayGet(newCenters, j);
382+
vec = GTypeVectorArrayGet(newCenters, j);
380383

381384
if (centerCounts[j] > 0)
382385
{
@@ -405,7 +408,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
405408

406409
// Step 5
407410
for (j = 0; j < numCenters; j++)
408-
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j))));
411+
newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(&centers->items[j]), PointerGetDatum(&newCenters->items[j])));
409412

410413
for (j = 0; j < numSamples; j++)
411414
{
@@ -427,7 +430,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers)
427430

428431
// Step 7
429432
for (j = 0; j < numCenters; j++)
430-
memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions));
433+
memcpy(&(centers->items[j]), &(newCenters->items[j]), VECTOR_SIZE(dimensions));
431434

432435
if (changes == 0 && iteration != 0)
433436
break;
@@ -460,7 +463,8 @@ CheckCenters(Relation index, VectorArray centers)
460463
// Ensure no NaN or infinite values
461464
for (int i = 0; i < centers->length; i++)
462465
{
463-
vec = VectorArrayGet(centers, i);
466+
//vec = GTypeVectorArrayGet(centers, i);
467+
vec = &(centers->items[i]);
464468

465469
for (int j = 0; j < AGT_ROOT_COUNT(vec); j++) {
466470
if (isnan((double) vec->root.children[1 + (j * sizeof(float8))]))
@@ -473,10 +477,11 @@ CheckCenters(Relation index, VectorArray centers)
473477

474478
// Ensure no duplicate centers
475479
// Fine to sort in-place
476-
qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors);
480+
qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), GtypeCompareVectors);
477481
for (int i = 1; i < centers->length; i++)
478482
{
479-
if (CompareVectors(VectorArrayGet(centers, i), VectorArrayGet(centers, i - 1)) == 0)
483+
//if (GtypeCompareVectors(GTypeVectorArrayGet(centers, i), GTypeVectorArrayGet(centers, i - 1)) == 0)
484+
if (GtypeCompareVectors(&(centers->items[i]), &(centers->items[i - 1])) == 0)
480485
elog(ERROR, "Duplicate centers detected. Please report a bug.");
481486
}
482487

@@ -489,7 +494,7 @@ CheckCenters(Relation index, VectorArray centers)
489494

490495
for (int i = 0; i < centers->length; i++)
491496
{
492-
norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i))));
497+
norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(&(centers->items[i]))));
493498
if (norm == 0)
494499
elog(ERROR, "Zero norm detected. Please report a bug.");
495500
}

src/backend/access/ivfutils.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,18 +308,18 @@ Datum create_ivfflat_l2_ops_index(PG_FUNCTION_ARGS)
308308
* Allocate a vector array
309309
*/
310310
VectorArray
311-
VectorArrayInit(int maxlen, int dimensions) {
311+
GtypeVectorArrayInit(int maxlen, int dimensions) {
312312
VectorArray res = palloc0(sizeof(VectorArrayData));
313313

314314
res->length = 0;
315315
res->maxlen = maxlen;
316316
res->dim = dimensions;
317317
int gtype_size = VECTOR_SIZE(dimensions) * maxlen;
318318

319-
res->items = palloc_extended(gtype_size * 2, MCXT_ALLOC_ZERO | MCXT_ALLOC_HUGE);
319+
res->items = palloc_extended(gtype_size, MCXT_ALLOC_ZERO | MCXT_ALLOC_HUGE);
320320

321-
for (int i = 0; i < dimensions; i++) {
322-
gtype *vec = VectorArrayGet(res, i);
321+
for (int i = 0; i < maxlen; i++) {
322+
gtype *vec = &res->items[i];// GTypeVectorArrayGet(res, i);
323323

324324
SET_VARSIZE(vec, VECTOR_SIZE(dimensions));
325325
vec->root.header = dimensions | GT_FEXTENDED_COMPOSITE;

src/backend/parser/cypher_gram.y

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17752,7 +17752,7 @@ cypher_in_expr:
1775217752
;
1775317753

1775417754
cypher_b_expr:
17755-
/* cypher_b_expr OR cypher_b_expr
17755+
cypher_b_expr OR cypher_b_expr
1775617756
{
1775717757
$$ = make_or_expr($1, $3, @2);
1775817758
}
@@ -17764,7 +17764,7 @@ cypher_b_expr:
1776417764
{
1776517765
$$ = make_xor_expr($1, $3, @2);
1776617766
}
17767-
| NOT cypher_b_expr
17767+
/*| NOT cypher_b_expr
1776817768
{
1776917769
$$ = make_not_expr($2, @1);
1777017770
}
@@ -17922,11 +17922,11 @@ cypher_expr_atom:
1792217922
else
1792317923
$$ = $2;
1792417924
}
17925-
/*
17925+
1792617926
| cypher_var_name '{' map_proj_list_opt '}' %prec '{'
1792717927
{
1792817928
ereport(ERROR, errmsg("map projections are not yet implemented"));
17929-
}*/
17929+
}
1793017930
| expr_case
1793117931
| cypher_expr_func
1793217932
| '(' cypher_query ')' %prec UNARY_MINUS

src/backend/utils/adt/vector.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ InitVectorGType(int dim)
4848

4949
gtype_value *result = (gtype_value *) palloc(sizeof(gtype_value));
5050

51+
result->type = AGTV_VECTOR;
52+
5153
result->val.vector.dim = dim;
5254

5355
result->val.vector.x = palloc(sizeof(float8) * dim);

src/include/access/ivfflat.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,17 @@ typedef IvfflatScanOpaqueData * IvfflatScanOpaque;
278278
#define VECTOR_ARRAY_SIZE(_length, _dim) \
279279
(sizeof(VectorArrayData) + ((_length) * VECTOR_SIZE(_dim)))
280280

281-
#define VECTOR_ARRAY_OFFSET(_arr, _offset) \
281+
#define GTYPE_VECTOR_ARRAY_OFFSET(_arr, _offset) \
282282
((char*) (_arr)->items + ((_offset) * VECTOR_SIZE((_arr)->dim)))
283283

284-
#define VectorArrayGet(_arr, _offset) \
285-
((gtype *) VECTOR_ARRAY_OFFSET(_arr, _offset))
284+
#define GTypeVectorArrayGet(_arr, _offset) \
285+
((gtype *) GTYPE_VECTOR_ARRAY_OFFSET(_arr, _offset))
286286

287287
#define VectorArraySet(_arr, _offset, _val) \
288-
memcpy(VECTOR_ARRAY_OFFSET(_arr, _offset), _val, VECTOR_SIZE((_arr)->dim))
288+
memcpy(GTYPE_VECTOR_ARRAY_OFFSET(_arr, _offset), _val, VECTOR_SIZE((_arr)->dim))
289289

290290
/* Methods */
291-
VectorArray VectorArrayInit(int maxlen, int dimensions);
291+
VectorArray GtypeVectorArrayInit(int maxlen, int dimensions);
292292
void VectorArrayFree(VectorArray arr);
293293
void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers);
294294
FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum);

src/include/utils/gtype.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -538,11 +538,11 @@ struct gtype_value
538538
BOX3D box3d;
539539
SPHEROID spheroid;
540540
GSERIALIZED *gserialized;
541-
TSVector tsvector;
542-
TSQuery tsquery;
543-
RangeType *range;
544-
MultirangeType *multirange;
545-
struct { int len; gtype_container *data; } binary; // Array or object, in on-disk format
541+
TSVector tsvector;
542+
TSQuery tsquery;
543+
RangeType *range;
544+
MultirangeType *multirange;
545+
struct { int len; gtype_container *data; } binary; // Array or object, in on-disk format
546546
} val;
547547
};
548548

src/include/utils/vector.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
* Portions Copyright (c) 2021-2023, pgvector
1818
*/
1919

20-
#ifndef VECTOR_H
21-
#define VECTOR_H
20+
#ifndef POSTGRAPH_VECTOR_H
21+
#define POSTGRAPH_VECTOR_H
2222

2323
#include "postgres.h"
2424

0 commit comments

Comments
 (0)