Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions db/migrations/20260202093502_unique_fields/migration.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*
Warnings:

- A unique constraint covering the columns `[title,categoryId]` on the table `Post` will be added. If there are existing duplicate values, this will fail.

*/
-- CreateIndex
CREATE UNIQUE INDEX "Post_title_categoryId_key" ON "Post"("title", "categoryId");
6 changes: 4 additions & 2 deletions db/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@ model Post {
published Boolean @default(false)
viewCount Int @default(0)
categoryId Int
author User? @relation(fields: [authorId], references: [id])
authorId Int?
comments Comment[]
author User? @relation(fields: [authorId], references: [id])
category Category @relation(fields: [categoryId], references: [id])

@@unique([title, categoryId])
}

model Comment {
id Int @id @default(autoincrement())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
content String?
post Post? @relation(fields: [postId], references: [id])
postId Int?
post Post? @relation(fields: [postId], references: [id])
}

model Category {
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "prisma-rls",
"version": "0.5.5",
"version": "0.5.6",
"description": "Prisma client extension for row-level security on any database",
"license": "MIT",
"keywords": [
Expand Down
11 changes: 9 additions & 2 deletions src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { AllOperationsArgs, ExtensionOptions, RecursiveContext, RelationMetadata
import { buildFieldsMap, generateImpossibleWhere, getTransactionClient } from "./utils";

export const createRlsExtension = ({ dmmf, permissionsConfig, context, authorizationError, checkRequiredBelongsTo }: ExtensionOptions) => {
const fieldsMap = buildFieldsMap(dmmf);
const { fieldsMap, uniqueFields } = buildFieldsMap(dmmf);

if (!authorizationError) {
authorizationError = new AuthorizationError();
Expand All @@ -23,7 +23,14 @@ export const createRlsExtension = ({ dmmf, permissionsConfig, context, authoriza
const { model: modelName, operation: operationName, args, query } = allOperationsArgs;

const modelPermissions = permissionsConfig[modelName];
const modelResolver = new ModelResolver(permissionsConfig, context, fieldsMap, authorizationError, !!checkRequiredBelongsTo);
const modelResolver = new ModelResolver(
permissionsConfig,
context,
fieldsMap,
uniqueFields,
authorizationError,
!!checkRequiredBelongsTo,
);

switch (operationName) {
case "findUnique":
Expand Down
70 changes: 47 additions & 23 deletions src/logic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import type {
PrismaTypeMap,
RecursiveContext,
RelationMetadata,
UniqueFieldsMap,
} from "./types";
import {
generateImpossibleWhere,
Expand All @@ -38,6 +39,7 @@ export class ModelResolver {
protected permissionsConfig: PermissionsConfig<PrismaTypeMap, unknown>,
protected context: unknown,
protected fieldsMap: FieldsMap,
protected uniqueFieldsMap: UniqueFieldsMap,
protected authorizationError: Error,
protected checkRequiredBelongsTo: boolean,
) {}
Expand All @@ -62,11 +64,21 @@ export class ModelResolver {
whereUnique: Record<string, unknown>,
) {
if (!permissionDefinition) {
return mergeWhereUnique(this.fieldsMap[modelName], whereUnique, generateImpossibleWhere(this.fieldsMap[modelName]));
return mergeWhereUnique(
this.fieldsMap[modelName],
this.uniqueFieldsMap[modelName],
whereUnique,
generateImpossibleWhere(this.fieldsMap[modelName]),
);
} else if (permissionDefinition === true) {
return whereUnique;
} else {
return mergeWhereUnique(this.fieldsMap[modelName], whereUnique, await resolvePermissionDefinition(permissionDefinition, this.context));
return mergeWhereUnique(
this.fieldsMap[modelName],
this.uniqueFieldsMap[modelName],
whereUnique,
await resolvePermissionDefinition(permissionDefinition, this.context),
);
}
}

Expand Down Expand Up @@ -231,6 +243,7 @@ export class ModelResolver {
const relationModelName = fieldDef.type;
const relationFields = this.fieldsMap[relationModelName];
const relationPermissions = this.permissionsConfig[relationModelName];
const uniqueFields = this.uniqueFieldsMap[relationModelName];

if (fieldDef.isList) {
return mapObjectValues(dataValue as ModelCreateNestedManyInput, async ([actionName, actionValue]) => {
Expand Down Expand Up @@ -258,7 +271,7 @@ export class ModelResolver {
return transformValue(actionValue, async (value) => {
return {
create: await this.resolveCreate(relationModelName, value.create),
where: mergeWhereUnique(relationFields, value.where, generateImpossibleWhere(relationFields)),
where: mergeWhereUnique(relationFields, uniqueFields, value.where, generateImpossibleWhere(relationFields)),
};
});
} else if (relationPermissions.read !== true) {
Expand All @@ -267,7 +280,7 @@ export class ModelResolver {
return transformValue(actionValue, async (value) => {
return {
create: await this.resolveCreate(relationModelName, value.create),
where: mergeWhereUnique(relationFields, value.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, value.where, permissionDefinition),
};
});
} else {
Expand All @@ -282,13 +295,13 @@ export class ModelResolver {
case "connect":
if (!relationPermissions.read) {
return transformValue(actionValue, async (value) => {
return mergeWhereUnique(relationFields, value, generateImpossibleWhere(relationFields));
return mergeWhereUnique(relationFields, uniqueFields, value, generateImpossibleWhere(relationFields));
});
} else if (relationPermissions.read !== true) {
const permissionDefinition = await resolvePermissionDefinition(relationPermissions.read, this.context);

return transformValue(actionValue, async (value) => {
return mergeWhereUnique(relationFields, value, permissionDefinition);
return mergeWhereUnique(relationFields, uniqueFields, value, permissionDefinition);
});
}
break;
Expand Down Expand Up @@ -316,7 +329,7 @@ export class ModelResolver {
if (!relationPermissions.read) {
return {
create: await this.resolveCreate(relationModelName, actionValue.create),
where: mergeWhereUnique(relationFields, actionValue.where, generateImpossibleWhere(relationFields)),
where: mergeWhereUnique(relationFields, uniqueFields, actionValue.where, generateImpossibleWhere(relationFields)),
};
} else if (relationPermissions.read !== true) {
const [create, permissionDefinition] = await Promise.all([
Expand All @@ -326,7 +339,7 @@ export class ModelResolver {

return {
create,
where: mergeWhereUnique(relationFields, actionValue.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, actionValue.where, permissionDefinition),
};
} else {
return {
Expand All @@ -337,9 +350,14 @@ export class ModelResolver {
}
case "connect":
if (!relationPermissions.read) {
return mergeWhereUnique(relationFields, actionValue, generateImpossibleWhere(relationFields));
return mergeWhereUnique(relationFields, uniqueFields, actionValue, generateImpossibleWhere(relationFields));
} else if (relationPermissions.read !== true) {
return mergeWhereUnique(relationFields, actionValue, await resolvePermissionDefinition(relationPermissions.read, this.context));
return mergeWhereUnique(
relationFields,
uniqueFields,
actionValue,
await resolvePermissionDefinition(relationPermissions.read, this.context),
);
}
break;
default:
Expand All @@ -362,6 +380,7 @@ export class ModelResolver {
const relationModelName = fieldDef.type;
const relationFields = this.fieldsMap[relationModelName];
const relationPermissions = this.permissionsConfig[relationModelName];
const uniqueFields = this.uniqueFieldsMap[relationModelName];

if (fieldDef.isList) {
return mapObjectValues(dataValue as ModelUpdateNestedManyInput, async ([actionName, actionValue]) => {
Expand Down Expand Up @@ -389,7 +408,7 @@ export class ModelResolver {
return transformValue(actionValue, async (value) => {
return {
create: await this.resolveCreate(relationModelName, value.create),
where: mergeWhereUnique(relationFields, value.where, generateImpossibleWhere(relationFields)),
where: mergeWhereUnique(relationFields, uniqueFields, value.where, generateImpossibleWhere(relationFields)),
};
});
} else if (relationPermissions.read !== true) {
Expand All @@ -398,7 +417,7 @@ export class ModelResolver {
return transformValue(actionValue, async (value) => {
return {
create: await this.resolveCreate(relationModelName, value.create),
where: mergeWhereUnique(relationFields, value.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, value.where, permissionDefinition),
};
});
} else {
Expand All @@ -415,13 +434,13 @@ export class ModelResolver {
case "disconnect":
if (!relationPermissions.read) {
return transformValue(actionValue, async (value) => {
return mergeWhereUnique(relationFields, value, generateImpossibleWhere(relationFields));
return mergeWhereUnique(relationFields, uniqueFields, value, generateImpossibleWhere(relationFields));
});
} else if (relationPermissions.read !== true) {
const permissionDefinition = await resolvePermissionDefinition(relationPermissions.read, this.context);

return transformValue(actionValue, async (value) => {
return mergeWhereUnique(relationFields, value, permissionDefinition);
return mergeWhereUnique(relationFields, uniqueFields, value, permissionDefinition);
});
}
break;
Expand All @@ -434,7 +453,7 @@ export class ModelResolver {
return transformValue(actionValue, async (value) => {
return {
data: await this.resolveUpdate(relationModelName, value.data),
where: mergeWhereUnique(relationFields, value.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, value.where, permissionDefinition),
};
});
} else {
Expand Down Expand Up @@ -474,7 +493,7 @@ export class ModelResolver {
return {
create,
update,
where: mergeWhereUnique(relationFields, value.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, value.where, permissionDefinition),
};
});
} else {
Expand All @@ -494,7 +513,7 @@ export class ModelResolver {
const permissionDefinition = await resolvePermissionDefinition(relationPermissions.delete, this.context);

return transformValue(actionValue, async (value) => {
return mergeWhereUnique(relationFields, value, permissionDefinition);
return mergeWhereUnique(relationFields, uniqueFields, value, permissionDefinition);
});
}
break;
Expand Down Expand Up @@ -533,7 +552,7 @@ export class ModelResolver {
if (!relationPermissions.read) {
return {
create: await this.resolveCreate(relationModelName, actionValue.create),
where: mergeWhereUnique(relationFields, actionValue.where, generateImpossibleWhere(relationFields)),
where: mergeWhereUnique(relationFields, uniqueFields, actionValue.where, generateImpossibleWhere(relationFields)),
};
} else if (relationPermissions.read !== true) {
const [create, permissionDefinition] = await Promise.all([
Expand All @@ -543,7 +562,7 @@ export class ModelResolver {

return {
create,
where: mergeWhereUnique(relationFields, actionValue.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, actionValue.where, permissionDefinition),
};
} else {
return {
Expand All @@ -554,9 +573,14 @@ export class ModelResolver {
}
case "connect":
if (!relationPermissions.read) {
return mergeWhereUnique(relationFields, actionValue, generateImpossibleWhere(relationFields));
return mergeWhereUnique(relationFields, uniqueFields, actionValue, generateImpossibleWhere(relationFields));
} else if (relationPermissions.read !== true) {
return mergeWhereUnique(relationFields, actionValue, await resolvePermissionDefinition(relationPermissions.read, this.context));
return mergeWhereUnique(
relationFields,
uniqueFields,
actionValue,
await resolvePermissionDefinition(relationPermissions.read, this.context),
);
}
break;
case "disconnect":
Expand Down Expand Up @@ -596,7 +620,7 @@ export class ModelResolver {

return {
data,
where: mergeWhereUnique(relationFields, actionValue.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, actionValue.where, permissionDefinition),
};
}
} else {
Expand Down Expand Up @@ -631,7 +655,7 @@ export class ModelResolver {
return {
create,
update,
where: mergeWhereUnique(relationFields, actionValue.where, permissionDefinition),
where: mergeWhereUnique(relationFields, uniqueFields, actionValue.where, permissionDefinition),
};
}
} else {
Expand Down
3 changes: 3 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export type DMMF = ReadonlyDeep<{
isId: boolean;
type: string;
}[];
uniqueFields: string[][];
}[];
};
}>;
Expand Down Expand Up @@ -77,6 +78,8 @@ export type AllOperationsArgs = { model: PrismaModelName<PrismaTypeMap> } & (

export type FieldsMap = Record<string, Record<string, DMMFField>>;

export type UniqueFieldsMap = Record<string, string[]>;

export type RelationMetadata = {
type: "requiredBelongsTo";
path: string;
Expand Down
12 changes: 8 additions & 4 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import type { Prisma, PrismaClientExtends } from "@prisma/client/extension";

import { AllOperationsArgs, DMMF, DMMFField, FieldsMap, ObjectEntry } from "./types";
import { AllOperationsArgs, DMMF, DMMFField, FieldsMap, ObjectEntry, UniqueFieldsMap } from "./types";

export function buildFieldsMap(dmmf: DMMF): FieldsMap {
export function buildFieldsMap(dmmf: DMMF): { fieldsMap: FieldsMap; uniqueFields: UniqueFieldsMap } {
const fieldsMap: FieldsMap = {};
const uniqueFields: UniqueFieldsMap = {};

for (const model of dmmf.datamodel.models) {
fieldsMap[model.name] = {};
uniqueFields[model.name] = model.uniqueFields.map((uniqueField) => uniqueField.join("_"));

for (const field of model.fields) {
fieldsMap[model.name][field.name] = field;
}
}

return fieldsMap;
return { fieldsMap, uniqueFields };
}

export function getTransactionClient(prismaClient: PrismaClientExtends, allOperationsArgs: AllOperationsArgs): Prisma.TransactionClient {
Expand Down Expand Up @@ -81,6 +84,7 @@ export function mergeWhere(first: Record<string, unknown> | undefined, second: R

export function mergeWhereUnique(
fields: Record<string, DMMFField>,
uniqueFields: string[],
firstUnique: Record<string, unknown>,
second: Record<string, unknown>,
): Record<string, unknown> {
Expand All @@ -90,7 +94,7 @@ export function mergeWhereUnique(
for (const [fieldName, fieldValue] of Object.entries(firstUnique)) {
const fieldDef = fields[fieldName];

if (isUniqueField(fieldDef)) {
if (uniqueFields.includes(fieldName) || isUniqueField(fieldDef)) {
unique[fieldName] = fieldValue;
} else {
rest[fieldName] = fieldValue;
Expand Down
6 changes: 6 additions & 0 deletions tests/read.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ describe("model reading", () => {
const user = db.user.findUnique({ where: { id: 1 } });
await expect(user).resolves.toEqual(null);
});

test("if read is where with compound unique it return filtered result", async () => {
const db = resolveDb({ Post: { read: { id: { equals: 1 } } } });
const post = db.post.findUnique({ where: { title_categoryId: { title: "Quick bites", categoryId: 1 } } });
await expect(post).resolves.toMatchObject({ id: 1 });
});
});

describe("find unique or throw", () => {
Expand Down
Loading