Skip to content

Commit 01de42e

Browse files
committed
fix(delegate): sort by fields from delegate base
1 parent bf805f0 commit 01de42e

32 files changed

+1878
-10
lines changed

packages/language/src/index.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export class DocumentLoadError extends Error {
2020

2121
export async function loadDocument(
2222
fileName: string,
23-
pluginModelFiles: string[] = [],
23+
additionalModelFiles: string[] = [],
2424
): Promise<
2525
{ success: true; model: Model; warnings: string[] } | { success: false; errors: string[]; warnings: string[] }
2626
> {
@@ -50,9 +50,9 @@ export async function loadDocument(
5050
URI.file(path.resolve(path.join(_dirname, '../res', STD_LIB_MODULE_NAME))),
5151
);
5252

53-
// load plugin model files
53+
// load additional model files
5454
const pluginDocs = await Promise.all(
55-
pluginModelFiles.map((file) =>
55+
additionalModelFiles.map((file) =>
5656
services.shared.workspace.LangiumDocuments.getOrCreateDocument(URI.file(path.resolve(file))),
5757
),
5858
);

packages/plugins/policy/src/policy-handler.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,20 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
127127
// --- Post mutation work ---
128128

129129
if (hasPostUpdatePolicies && result.rows.length > 0) {
130+
// verify if before-update rows and post-update rows still id-match
131+
if (beforeUpdateInfo) {
132+
invariant(beforeUpdateInfo.rows.length === result.rows.length);
133+
const idFields = QueryUtils.requireIdFields(this.client.$schema, mutationModel);
134+
for (const postRow of result.rows) {
135+
const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
136+
if (!beforeRow) {
137+
throw new QueryError(
138+
'Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.',
139+
);
140+
}
141+
}
142+
}
143+
130144
// entities updated filter
131145
const idConditions = this.buildIdConditions(mutationModel, result.rows);
132146

packages/runtime/src/client/crud/dialects/base-dialect.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
826826
}
827827

828828
let result = query;
829+
830+
const buildFieldRef = (model: string, field: string, modelAlias: string) => {
831+
const fieldDef = requireField(this.schema, model, field);
832+
const eb = expressionBuilder<any, any>();
833+
return fieldDef.originModel
834+
? this.fieldRef(fieldDef.originModel, field, eb, fieldDef.originModel)
835+
: this.fieldRef(model, field, eb, modelAlias);
836+
};
837+
829838
enumerate(orderBy).forEach((orderBy) => {
830839
for (const [field, value] of Object.entries<any>(orderBy)) {
831840
if (!value) {
@@ -838,8 +847,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
838847
for (const [k, v] of Object.entries<SortOrder>(value)) {
839848
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
840849
result = result.orderBy(
841-
(eb) =>
842-
aggregate(eb, this.fieldRef(model, k, eb, modelAlias), field as AGGREGATE_OPERATORS),
850+
(eb) => aggregate(eb, buildFieldRef(model, k, modelAlias), field as AGGREGATE_OPERATORS),
843851
sql.raw(this.negateSort(v, negated)),
844852
);
845853
}
@@ -852,7 +860,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
852860
for (const [k, v] of Object.entries<string>(value)) {
853861
invariant(v === 'asc' || v === 'desc', `invalid orderBy value for field "${field}"`);
854862
result = result.orderBy(
855-
(eb) => eb.fn.count(this.fieldRef(model, k, eb, modelAlias)),
863+
(eb) => eb.fn.count(buildFieldRef(model, k, modelAlias)),
856864
sql.raw(this.negateSort(v, negated)),
857865
);
858866
}
@@ -865,7 +873,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
865873
const fieldDef = requireField(this.schema, model, field);
866874

867875
if (!fieldDef.relation) {
868-
const fieldRef = this.fieldRef(model, field, expressionBuilder(), modelAlias);
876+
const fieldRef = buildFieldRef(model, field, modelAlias);
869877
if (value === 'asc' || value === 'desc') {
870878
result = result.orderBy(fieldRef, this.negateSort(value, negated));
871879
} else if (

packages/testtools/src/schema.ts

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,29 @@ export async function generateTsSchemaInPlace(schemaPath: string) {
9191
return compileAndLoad(workDir);
9292
}
9393

94-
export async function loadSchema(schema: string) {
94+
export async function loadSchema(schema: string, additionalSchemas?: Record<string, string>) {
9595
if (!schema.includes('datasource ')) {
9696
schema = `${makePrelude('sqlite')}\n\n${schema}`;
9797
}
9898

99+
// create a temp folder
100+
const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'zenstack-schema'));
101+
99102
// create a temp file
100-
const tempFile = path.join(os.tmpdir(), `zenstack-schema-${crypto.randomUUID()}.zmodel`);
103+
const tempFile = path.join(tempDir, `schema.zmodel`);
101104
fs.writeFileSync(tempFile, schema);
105+
106+
if (additionalSchemas) {
107+
for (const [fileName, content] of Object.entries(additionalSchemas)) {
108+
let name = fileName;
109+
if (!name.endsWith('.zmodel')) {
110+
name += '.zmodel';
111+
}
112+
const filePath = path.join(tempDir, name);
113+
fs.writeFileSync(filePath, content);
114+
}
115+
}
116+
102117
const r = await loadDocument(tempFile);
103118
expect(r).toSatisfy(
104119
(r) => r.success,
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import { createPolicyTestClient } from '@zenstackhq/testtools';
2+
import { describe, expect, it } from 'vitest';
3+
4+
// TODO: field-level policy support
5+
describe.skip('Regression for issue 1014', () => {
6+
it('update', async () => {
7+
const db = await createPolicyTestClient(
8+
`
9+
model User {
10+
id Int @id() @default(autoincrement())
11+
name String
12+
posts Post[]
13+
}
14+
15+
model Post {
16+
id Int @id() @default(autoincrement())
17+
title String
18+
content String?
19+
author User? @relation(fields: [authorId], references: [id])
20+
authorId Int? @allow('update', true, true)
21+
22+
@@allow('read', true)
23+
}
24+
`,
25+
);
26+
27+
const user = await db.$unuseAll().user.create({ data: { name: 'User1' } });
28+
const post = await db.$unuseAll().post.create({ data: { title: 'Post1' } });
29+
await expect(db.post.update({ where: { id: post.id }, data: { authorId: user.id } })).toResolveTruthy();
30+
});
31+
32+
it('read', async () => {
33+
const db = await createPolicyTestClient(
34+
`
35+
model Post {
36+
id Int @id() @default(autoincrement())
37+
title String @allow('read', true, true)
38+
content String
39+
}
40+
`,
41+
);
42+
43+
const post = await db.$unuseAll().post.create({ data: { title: 'Post1', content: 'Content' } });
44+
await expect(db.post.findUnique({ where: { id: post.id } })).toResolveNull();
45+
await expect(db.post.findUnique({ where: { id: post.id }, select: { title: true } })).resolves.toEqual({
46+
title: 'Post1',
47+
});
48+
});
49+
});
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import { createTestClient } from '@zenstackhq/testtools';
2+
import { it } from 'vitest';
3+
4+
it('verifies issue 1058', async () => {
5+
const schema = `
6+
model User {
7+
id String @id @default(cuid())
8+
name String
9+
10+
userRankings UserRanking[]
11+
userFavorites UserFavorite[]
12+
}
13+
14+
model Entity {
15+
id String @id @default(cuid())
16+
name String
17+
type String
18+
userRankings UserRanking[]
19+
userFavorites UserFavorite[]
20+
21+
@@delegate(type)
22+
}
23+
24+
model Person extends Entity {
25+
}
26+
27+
model Studio extends Entity {
28+
}
29+
30+
31+
model UserRanking {
32+
id String @id @default(cuid())
33+
rank Int
34+
35+
entityId String
36+
entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction)
37+
userId String
38+
user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction)
39+
}
40+
41+
model UserFavorite {
42+
id String @id @default(cuid())
43+
44+
entityId String
45+
entity Entity @relation(fields: [entityId], references: [id], onUpdate: NoAction)
46+
userId String
47+
user User @relation(fields: [userId], references: [id], onDelete: Cascade, onUpdate: NoAction)
48+
}
49+
`;
50+
51+
await createTestClient(schema, { provider: 'postgresql' });
52+
});
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import { createPolicyTestClient } from '@zenstackhq/testtools';
2+
import { describe, expect, it } from 'vitest';
3+
4+
describe('Regression for issue 1078', () => {
5+
it('regression1', async () => {
6+
const db = await createPolicyTestClient(
7+
`
8+
model Counter {
9+
id String @id
10+
11+
name String
12+
value Int
13+
14+
@@validate(value >= 0)
15+
@@allow('all', true)
16+
}
17+
`,
18+
);
19+
20+
await expect(
21+
db.counter.create({
22+
data: { id: '1', name: 'It should create', value: 1 },
23+
}),
24+
).toResolveTruthy();
25+
26+
//! This query fails validation
27+
await expect(
28+
db.counter.update({
29+
where: { id: '1' },
30+
data: { name: 'It should update' },
31+
}),
32+
).toResolveTruthy();
33+
});
34+
35+
// TODO: field-level policy support
36+
it.skip('regression2', async () => {
37+
const db = await createPolicyTestClient(
38+
`
39+
model Post {
40+
id Int @id() @default(autoincrement())
41+
title String @allow('read', true, true)
42+
content String
43+
}
44+
`,
45+
);
46+
47+
const post = await db.$unuseAll().post.create({ data: { title: 'Post1', content: 'Content' } });
48+
await expect(db.post.findUnique({ where: { id: post.id } })).toResolveNull();
49+
await expect(db.post.findUnique({ where: { id: post.id }, select: { title: true } })).resolves.toEqual({
50+
title: 'Post1',
51+
});
52+
});
53+
});

0 commit comments

Comments
 (0)