Skip to content

Commit be8054e

Browse files
authored
feat: add max length for encrypted string (#290)
* feat: add max length for encrypted string * fix: updated exceptions and added tests * fix: skip mocks
1 parent 28c918e commit be8054e

File tree

7 files changed

+98
-34
lines changed

7 files changed

+98
-34
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
- id: unasyncd
2323
additional_dependencies: ["ruff"]
2424
- repo: https://github.com/charliermarsh/ruff-pre-commit
25-
rev: "v0.7.3"
25+
rev: "v0.7.4"
2626
hooks:
2727
# Run the linter.
2828
- id: ruff

advanced_alchemy/exceptions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
import re
55
from contextlib import contextmanager
6-
from typing import Any, Callable, Generator, TypedDict, Union
6+
from typing import Any, Callable, Generator, TypedDict, Union, cast
77

88
from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError
99
from sqlalchemy.exc import InvalidRequestError as SQLAlchemyInvalidRequestError
10-
from sqlalchemy.exc import MultipleResultsFound, SQLAlchemyError
10+
from sqlalchemy.exc import MultipleResultsFound, SQLAlchemyError, StatementError
1111

1212
from advanced_alchemy.utils.deprecation import deprecated
1313

@@ -291,6 +291,7 @@ def wrap_sqlalchemy_exception(
291291
"""
292292
try:
293293
yield
294+
294295
except MultipleResultsFound as exc:
295296
if error_messages is not None:
296297
msg = _get_error_message(error_messages=error_messages, key="multiple_rows", exc=exc)
@@ -318,6 +319,10 @@ def wrap_sqlalchemy_exception(
318319
raise IntegrityError(detail=f"An integrity error occurred: {exc}") from exc
319320
except SQLAlchemyInvalidRequestError as exc:
320321
raise InvalidRequestError(detail="An invalid request was made.") from exc
322+
except StatementError as exc:
323+
raise IntegrityError(
324+
detail=cast(str, getattr(exc.orig, "detail", "There was an issue processing the statement."))
325+
) from exc
321326
except SQLAlchemyError as exc:
322327
if error_messages is not None:
323328
msg = _get_error_message(error_messages=error_messages, key="other", exc=exc)

advanced_alchemy/types/encrypted_string.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from sqlalchemy import String, Text, TypeDecorator
1010
from sqlalchemy import func as sql_func
1111

12+
from advanced_alchemy.exceptions import IntegrityError
13+
1214
if TYPE_CHECKING:
1315
from sqlalchemy.engine import Dialect
1416

@@ -222,11 +224,13 @@ class EncryptedString(TypeDecorator[str]):
222224
Args:
223225
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
224226
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
227+
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only, as encrypted strings will be longer.
225228
**kwargs (Any | None): Additional arguments passed to the underlying String type.
226229
227230
Attributes:
228231
key (str | bytes | Callable[[], str | bytes]): The encryption key.
229232
backend (EncryptionBackend): The encryption backend instance.
233+
length (int | None): The unencrypted string length.
230234
"""
231235

232236
impl = String
@@ -236,18 +240,21 @@ def __init__(
236240
self,
237241
key: str | bytes | Callable[[], str | bytes] = os.urandom(32),
238242
backend: type[EncryptionBackend] = FernetBackend,
243+
length: int | None = None,
239244
**kwargs: Any,
240245
) -> None:
241246
"""Initializes the EncryptedString TypeDecorator.
242247
243248
Args:
244249
key (str | bytes | Callable[[], str | bytes] | None): The encryption key. Can be a string, bytes, or callable returning either. Defaults to os.urandom(32).
245250
backend (Type[EncryptionBackend] | None): The encryption backend class to use. Defaults to FernetBackend.
251+
length (int | None): The length of the unencrypted string. This is used for documentation and validation purposes only.
246252
**kwargs (Any | None): Additional arguments passed to the underlying String type.
247253
"""
248254
super().__init__()
249255
self.key = key
250256
self.backend = backend()
257+
self.length = length
251258

252259
@property
253260
def python_type(self) -> type[str]:
@@ -261,32 +268,46 @@ def python_type(self) -> type[str]:
261268
def load_dialect_impl(self, dialect: Dialect) -> Any:
262269
"""Loads the appropriate dialect implementation based on the database dialect.
263270
271+
Note: The actual column length will be larger than the specified length due to encryption overhead.
272+
For most encryption methods, the encrypted string will be approximately 1.35x longer than the original.
273+
264274
Args:
265275
dialect (Dialect): The SQLAlchemy dialect.
266276
267277
Returns:
268278
Any: The dialect-specific type descriptor.
269279
"""
270280
if dialect.name in {"mysql", "mariadb"}:
281+
# For MySQL/MariaDB, always use Text to avoid length limitations
271282
return dialect.type_descriptor(Text())
272283
if dialect.name == "oracle":
284+
# Oracle has a 4000-byte limit for VARCHAR2 (by default)
273285
return dialect.type_descriptor(String(length=4000))
274286
return dialect.type_descriptor(String())
275287

276288
def process_bind_param(self, value: Any, dialect: Dialect) -> str | None:
277289
"""Processes the value before binding it to the SQL statement.
278290
279-
This method encrypts the value using the specified backend.
291+
This method encrypts the value using the specified backend and validates length if specified.
280292
281293
Args:
282294
value (Any): The value to process.
283295
dialect (Dialect): The SQLAlchemy dialect.
284296
285297
Returns:
286298
str | None: The encrypted value or None if the input is None.
299+
300+
Raises:
301+
ValueError: If the value exceeds the specified length.
287302
"""
288303
if value is None:
289304
return value
305+
306+
# Validate length if specified
307+
if self.length is not None and len(str(value)) > self.length:
308+
msg = f"Unencrypted value exceeds maximum unencrypted length of {self.length}"
309+
raise IntegrityError(msg)
310+
290311
self.mount_vault()
291312
return self.backend.encrypt(value)
292313

tests/fixtures/bigint/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,7 @@ class BigIntSecret(BigIntBase):
109109
long_secret: Mapped[str] = mapped_column(
110110
EncryptedText(key="super_secret"),
111111
)
112+
length_validated_secret: Mapped[str] = mapped_column(
113+
EncryptedString(key="super_secret", length=10),
114+
nullable=True,
115+
)

tests/fixtures/uuid/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ class UUIDSecret(UUIDv7Base):
7373
long_secret: Mapped[str] = mapped_column(
7474
EncryptedText(key="super_secret"),
7575
)
76+
length_validated_secret: Mapped[str] = mapped_column(
77+
EncryptedString(key="super_secret", length=10),
78+
nullable=True,
79+
)
7680

7781

7882
class UUIDModelWithFetchedValue(UUIDv6Base):

tests/integration/test_repository.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
ModelWithFetchedValueRepository = SQLAlchemyAsyncRepository[AnyModelWithFetchedValue]
111111
ModelWithFetchedValueService = SQLAlchemyAsyncRepositoryService[AnyModelWithFetchedValue]
112112

113+
113114
RawRecordData = List[Dict[str, Any]]
114115

115116
mock_engines = {"mock_async_engine", "mock_sync_engine"}
@@ -1937,6 +1938,35 @@ async def test_repo_encrypted_methods(
19371938
assert obj.long_secret == updated.long_secret
19381939

19391940

1941+
async def test_encrypted_string_length_validation(
1942+
request: FixtureRequest, secret_repo: SecretRepository, secret_model: SecretModel
1943+
) -> None:
1944+
"""Test that EncryptedString enforces length validation.
1945+
1946+
Args:
1947+
secret_repo: The secret repository
1948+
secret_model: The secret model class
1949+
"""
1950+
if any(fixture in request.fixturenames for fixture in ["mock_async_engine", "mock_sync_engine"]):
1951+
pytest.skip(
1952+
f"{SQLAlchemyAsyncMockRepository.__name__} does not works with client side validated encrypted strings lengths"
1953+
)
1954+
# Test valid length
1955+
valid_secret = "AAAAAAAAA"
1956+
secret = secret_model(secret="test", long_secret="test", length_validated_secret=valid_secret)
1957+
saved_secret = await maybe_async(secret_repo.add(secret))
1958+
assert saved_secret.length_validated_secret == valid_secret
1959+
1960+
# Test exceeding length
1961+
long_secret = "A" * 51 # Exceeds 50 character limit
1962+
with pytest.raises(IntegrityError) as exc_info:
1963+
secret = secret_model(secret="test", long_secret="test", length_validated_secret=long_secret)
1964+
await maybe_async(secret_repo.add(secret))
1965+
1966+
assert exc_info.value.__class__.__name__ == "IntegrityError"
1967+
assert "exceeds maximum unencrypted length" in str(exc_info.value.detail)
1968+
1969+
19401970
# service tests
19411971
async def test_service_filter_search(author_service: AuthorService) -> None:
19421972
existing_obj = await maybe_async(

0 commit comments

Comments
 (0)