Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,9 @@ def execute_sql(self, result_type):
f"{field.__class__.__name__}."
)
prepared = field.get_db_prep_save(value, connection=self.connection)
if hasattr(value, "as_mql"):
if is_direct_value(value):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still question whether the $literal escaping is overly broad. Better safe than sorry, sure, but it decreases readability a bit and makes queries larger. But I'm not sure how to compare the tradeoffs between some more complicated logic (isinstance() CPU time) and leaving it as is. It seems if we did wrapping of only the types that Value() wraps, it should be safe, unless the logic in Value() is deficient. And are there any string values besides those that start with $ that could be problematic? Perhaps to be discussed in chat tomorrow.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 I agree. Seeing $literal everywhere is annoying, but Value's resolver is probably covering the most common types used in queries. On the other hand, implementing the same escaping logic that Value uses is not a big deal.

prepared = {"$literal": prepared}
else:
prepared = prepared.as_mql(self, self.connection, as_expr=True)
values[field.column] = prepared
try:
Expand Down
6 changes: 3 additions & 3 deletions django_mongodb_backend/expressions/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def when(self, compiler, connection):

def value(self, compiler, connection, as_expr=False): # noqa: ARG001
value = self.value
if isinstance(value, (list, int)) and as_expr:
# Wrap lists & numbers in $literal to prevent ambiguity when Value
# appears in $project.
if isinstance(value, (list, int, str, dict, tuple)) and as_expr:
# Wrap lists, numbers, strings, dicts, and tuples in $literal to avoid
# ambiguity when Value is used in aggregate() or update_many()'s $set.
return {"$literal": value}
if isinstance(value, Decimal):
return Decimal128(value)
Expand Down
30 changes: 26 additions & 4 deletions django_mongodb_backend/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class MongoTestCaseMixin:
maxDiff = None
query_types = {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}

def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
"""
Expand All @@ -15,7 +16,28 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline):
_, collection, operator = prefix.split(".")
self.assertEqual(operator, "aggregate")
self.assertEqual(collection, expected_collection)
self.assertEqual(
eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307
expected_pipeline,
)
self.assertEqual(eval(pipeline[:-1], self.query_types, {}), expected_pipeline) # noqa: S307

def assertInsertQuery(self, query, expected_collection, expected_documents):
"""
Assert that the logged query is equal to:
db.{expected_collection}.insert_many({expected_documents})
"""
prefix, pipeline = query.split("(", 1)
_, collection, operator = prefix.split(".")
self.assertEqual(operator, "insert_many")
self.assertEqual(collection, expected_collection)
self.assertEqual(eval(pipeline[:-1], self.query_types), expected_documents) # noqa: S307

def assertUpdateQuery(self, query, expected_collection, expected_condition, expected_set):
"""
Assert that the logged query is equal to:
db.{expected_collection}.update_many({expected_condition}, {expected_set})
"""
prefix, pipeline = query.split("(", 1)
_, collection, operator = prefix.split(".")
self.assertEqual(operator, "update_many")
self.assertEqual(collection, expected_collection)
condition, set_expression = eval(pipeline[:-1], self.query_types, {}) # noqa: S307
self.assertEqual(condition, expected_condition)
self.assertEqual(set_expression, expected_set)
5 changes: 5 additions & 0 deletions docs/releases/5.2.x.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Bug fixes

- Prevented ``QuerySet.union()`` queries from duplicating the ``$project``
pipeline.
- Made :class:`~django.db.models.Value` wrap strings in ``$literal`` to
prevent dollar-prefixed strings from being interpreted as expressions.
Also wrapped dictionaries and tuples to prevent the same for them.
- Made model update queries wrap values in ``$literal`` to prevent values from
being interpreted as expressions.

Performance improvements
------------------------
Expand Down
Empty file added tests/basic_/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions tests/basic_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import models


class Author(models.Model):
name = models.CharField(max_length=50)

def __str__(self):
return self.name


class Blob(models.Model):
name = models.CharField(max_length=10)
data = models.JSONField(null=True)

def __str__(self):
return self.name
100 changes: 100 additions & 0 deletions tests/basic_/test_escaping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Literals that MongoDB intreprets as expressions are escaped."""

from operator import attrgetter

from django.db.models import Value
from django.test import TestCase

from django_mongodb_backend.test import MongoTestCaseMixin

from .models import Author, Blob


class ModelCreationTests(MongoTestCaseMixin, TestCase):
def test_dollar_prefixed_string(self):
# No escaping is needed because MongoDB's insert doesn't treat
# dollar-prefixed strings as expressions.
with self.assertNumQueries(1) as ctx:
obj = Author.objects.create(name="$foobar")
obj.refresh_from_db()
self.assertEqual(obj.name, "$foobar")
self.assertInsertQuery(
ctx.captured_queries[0]["sql"], "basic__author", [{"name": "$foobar"}]
)


class ModelUpdateTests(MongoTestCaseMixin, TestCase):
"""
$-prefixed strings and dict/tuples that could be interpreted as expressions
are escaped in the queries that update model instances.
"""

def test_dollar_prefixed_string(self):
obj = Author.objects.create(name="foobar")
obj.name = "$updated"
with self.assertNumQueries(1) as ctx:
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.name, "$updated")
self.assertUpdateQuery(
ctx.captured_queries[0]["sql"],
"basic__author",
{"_id": obj.id},
[{"$set": {"name": {"$literal": "$updated"}}}],
)

def test_dollar_prefixed_value(self):
obj = Author.objects.create(name="foobar")
obj.name = Value("$updated")
with self.assertNumQueries(1) as ctx:
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.name, "$updated")
self.assertUpdateQuery(
ctx.captured_queries[0]["sql"],
"basic__author",
{"_id": obj.id},
[{"$set": {"name": {"$literal": "$updated"}}}],
)

def test_dict(self):
obj = Blob.objects.create(name="foobar")
obj.data = {"$concat": ["$name", "-", "$name"]}
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]})

def test_dict_value(self):
obj = Blob.objects.create(name="foobar", data={})
obj.data = Value({"$concat": ["$name", "-", "$name"]})
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, {"$concat": ["$name", "-", "$name"]})

def test_tuple(self):
obj = Blob.objects.create(name="foobar")
obj.data = ("$name", "-", "$name")
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, ["$name", "-", "$name"])

def test_tuple_value(self):
obj = Blob.objects.create(name="foobar")
obj.data = Value(("$name", "-", "$name"))
obj.save()
obj.refresh_from_db()
self.assertEqual(obj.data, ["$name", "-", "$name"])


class AnnotationTests(MongoTestCaseMixin, TestCase):
def test_dollar_prefixed_value(self):
"""Value() escapes dollar prefixed strings."""
Author.objects.create(name="Gustavo")
with self.assertNumQueries(1) as ctx:
qs = list(Author.objects.annotate(a_value=Value("$name")))
self.assertQuerySetEqual(qs, ["$name"], attrgetter("a_value"))
self.assertAggregateQuery(
ctx.captured_queries[0]["sql"],
"basic__author",
[{"$project": {"a_value": {"$literal": "$name"}, "_id": 1, "name": 1}}],
)
13 changes: 13 additions & 0 deletions tests/expressions_/test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def test_datetime(self):
def test_decimal(self):
self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), Decimal128("1.0"))

def test_dict_expr(self):
self.assertEqual(
Value({"$foo": "$bar"}).as_mql(None, None, as_expr=True), {"$literal": {"$foo": "$bar"}}
)

def test_list(self):
self.assertEqual(Value([1, 2]).as_mql(None, None, as_expr=True), {"$literal": [1, 2]})

Expand All @@ -41,6 +46,14 @@ def test_int(self):
def test_str(self):
self.assertEqual(Value("foo").as_mql(None, None), "foo")

def test_str_expr(self):
self.assertEqual(Value("$foo").as_mql(None, None, as_expr=True), {"$literal": "$foo"})

def test_tuple_expr(self):
self.assertEqual(
Value(("$foo", "$bar")).as_mql(None, None, as_expr=True), {"$literal": ("$foo", "$bar")}
)

def test_uuid(self):
value = uuid.UUID(int=1)
self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001")
14 changes: 12 additions & 2 deletions tests/model_fields_/test_embedded_model_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,18 @@ def test_nested_array_index_expr(self):
},
{
"$concat": [
{"$ifNull": ["Z", ""]},
{"$ifNull": ["acarias", ""]},
{
"$ifNull": [
{"$literal": "Z"},
{"$literal": ""},
]
},
{
"$ifNull": [
{"$literal": "acarias"},
{"$literal": ""},
]
},
]
},
]
Expand Down
Loading