Skip to content

Commit f6c14ff

Browse files
JCoxwellbckohan
andcommitted
Add create_from_super method and test
Co-authored-by: Joshua <[email protected]> Co-authored-by: Brian Kohan <[email protected]>
1 parent 447e8c2 commit f6c14ff

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

src/polymorphic/managers.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
The manager class for use in the models.
33
"""
44

5+
from django.contrib.contenttypes.models import ContentType
56
from django.db import models
67

78
from polymorphic.query import PolymorphicQuerySet
@@ -49,3 +50,35 @@ def not_instance_of(self, *args):
4950

5051
def get_real_instances(self, base_result_objects=None):
5152
return self.all().get_real_instances(base_result_objects=base_result_objects)
53+
54+
def create_from_super(self, obj, **kwargs):
55+
"""Creates an instance of self.model (cls) from existing super class.
56+
The new subclass will be the same object with same database id
57+
and data as obj, but will be an instance of cls.
58+
59+
obj must be an instance of the direct superclass of cls.
60+
kwargs should contain all required fields of the subclass (cls).
61+
62+
returns obj as an instance of cls.
63+
"""
64+
cls = self.model
65+
import inspect
66+
67+
scls = inspect.getmro(cls)[1]
68+
if scls is not type(obj):
69+
raise Exception(
70+
"create_from_super can only be used if obj is one level of inheritance up from cls"
71+
)
72+
ptr = "{}_ptr_id".format(scls.__name__.lower())
73+
kwargs[ptr] = obj.id
74+
# create the new base class with only fields that apply to it.
75+
nobj = cls(**kwargs)
76+
nobj.save_base(raw=True)
77+
# force update the content type, but first we need to
78+
# retrieve a clean copy from the db to fill in the null
79+
# fields otherwise they would be overwritten.
80+
nobj = obj.__class__.objects.get(pk=obj.pk)
81+
nobj.polymorphic_ctype = ContentType.objects.get_for_model(cls)
82+
nobj.save()
83+
84+
return nobj.get_real_instance() # cast to cls
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from django.test import TransactionTestCase
2+
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D
3+
4+
5+
class PolymorphicTests(TransactionTestCase):
6+
def test_create_from_super(self):
7+
# run create test 3 times because initial implementation
8+
# would fail after first success.
9+
for i in range(3):
10+
mc = Model2C.objects.create(
11+
field1="C1{}".format(i), field2="C2{}".format(i), field3="C3{}".format(i)
12+
)
13+
mc.save()
14+
field4 = "D4{}".format(i)
15+
md = Model2D.objects.create_from_super(mc, field4=field4)
16+
self.assertEqual(mc.id, md.id)
17+
self.assertEqual(mc.field1, md.field1)
18+
self.assertEqual(mc.field2, md.field2)
19+
self.assertEqual(mc.field3, md.field3)
20+
self.assertEqual(md.field4, field4)
21+
ma = Model2A.objects.create(field1="A1e")
22+
self.assertRaises(Exception, Model2D.objects.create_from_super, ma, field4="D4e")
23+
mb = Model2B.objects.create(field1="B1e", field2="B2e")
24+
self.assertRaises(Exception, Model2D.objects.create_from_super, mb, field4="D4e")

0 commit comments

Comments
 (0)