diff --git a/django_cloneable/models.py b/django_cloneable/models.py index 2c2d221..5328c35 100644 --- a/django_cloneable/models.py +++ b/django_cloneable/models.py @@ -52,6 +52,44 @@ def _clone_attrs(self, duplicate, attrs, exclude=None): for attname, value in attrs.items(): setattr(duplicate, attname, value) + def _clone_copy_fk(self, duplicate, exclude=None): + exclude = exclude or [] + foreign_keys = {} + + for field in self.instance._meta.related_objects: + # Skip this field. + if field.name in exclude: + continue + + # Check for one to many: + if field.one_to_many: + f_name = '%s_set' % field.name + + # Collect the objects which contain ForeignKey pointing to the source object + fks_to_copy = list(getattr(self.instance, f_name).all()) + + for fk in fks_to_copy: + # Empty primary key + fk.pk = None + + # Iterate fields in the classes which contain ForeignKey pointing to the source object. + # If the field has the same object as our source, we should rewrite it's value to point + # to our newly created duplicated record. + for fk_field in fk._meta.fields: + if fk_field.related_model: + if duplicate._meta.object_name == fk_field.related_model._meta.object_name: + setattr(fk, fk_field.name, duplicate) + + try: + # Use fk.__class__ here to avoid hard-coding the class name + foreign_keys[fk.__class__].append(fk) + except KeyError: + foreign_keys[fk.__class__] = [fk] + + # Insert the new records in the database + for cls, list_of_fks in foreign_keys.items(): + cls.objects.bulk_create(list_of_fks) + def _clone_copy_m2m(self, duplicate, exclude=None): exclude = exclude or [] # copy.copy loses all ManyToMany relations. @@ -142,6 +180,11 @@ def clone(self, attrs=None, commit=True, m2m_clone_reverse=True, clone_attrs = getattr(self.instance, '_clone_attrs', self._clone_attrs) clone_attrs(duplicate, attrs, exclude=exclude) + def clone_fk(): + clone_copy_fk = getattr(self.instance, '_clone_copy_fk', + self._clone_copy_fk) + clone_copy_fk(duplicate, exclude=exclude) + def clone_m2m(clone_reverse=m2m_clone_reverse): clone_copy_m2m = getattr(self.instance, '_clone_copy_m2m', self._clone_copy_m2m) @@ -155,8 +198,10 @@ def clone_m2m(clone_reverse=m2m_clone_reverse): if commit: duplicate.save(force_insert=True) + clone_fk() clone_m2m() else: + duplicate.clone_fk = clone_fk duplicate.clone_m2m = clone_m2m return duplicate @@ -203,6 +248,9 @@ def _clone_attrs(self, duplicate, attrs, exclude=None): return self._clone_helper._clone_attrs(duplicate, attrs, exclude=exclude) + def _clone_copy_fk(self, duplicate, exclude=None): + return self._clone_helper._clone_copy_fk(duplicate, exclude=exclude) + def _clone_copy_m2m(self, duplicate, exclude=None): return self._clone_helper._clone_copy_m2m(duplicate, exclude=exclude) @@ -219,6 +267,9 @@ def clone(self, attrs=None, commit=True, m2m_clone_reverse=True, self._clone_prepare(duplicate, exclude=exclude) self._clone_attrs(duplicate, attrs, exclude=exclude) + def clone_fk(): + self._clone_copy_fk(duplicate, exclude=exclude) + def clone_m2m(clone_reverse=m2m_clone_reverse): self._clone_copy_m2m(duplicate, exclude=exclude) if clone_reverse: @@ -227,8 +278,10 @@ def clone_m2m(clone_reverse=m2m_clone_reverse): if commit: duplicate.save(force_insert=True) + clone_fk() clone_m2m() else: + duplicate.clone_fk = clone_fk duplicate.clone_m2m = clone_m2m return duplicate