diff --git a/nibabel/cifti2/cifti2.py b/nibabel/cifti2/cifti2.py
index ec303642eb..3c47a3aa75 100644
--- a/nibabel/cifti2/cifti2.py
+++ b/nibabel/cifti2/cifti2.py
@@ -20,8 +20,9 @@
 from collections.abc import MutableSequence, MutableMapping, Iterable
 from collections import OrderedDict
 from .. import xmlutils as xml
-from ..filebasedimages import FileBasedHeader
+from ..filebasedimages import FileBasedHeader, SerializableImage
 from ..dataobj_images import DataobjImage
+from ..nifti1 import Nifti1Extensions
 from ..nifti2 import Nifti2Image, Nifti2Header
 from ..arrayproxy import reshape_dataobj
 from warnings import warn
@@ -1255,6 +1256,9 @@ def _to_xml_element(self):
             cifti.append(mat_xml)
         return cifti
 
+    def __eq__(self, other):
+        return self.to_xml() == other.to_xml()
+
     @classmethod
     def may_contain_header(klass, binaryblock):
         from .parse_cifti2 import _Cifti2AsNiftiHeader
@@ -1326,7 +1330,7 @@ def from_axes(cls, axes):
         return cifti2_axes.to_header(axes)
 
 
-class Cifti2Image(DataobjImage):
+class Cifti2Image(DataobjImage, SerializableImage):
     """ Class for single file CIFTI-2 format image
     """
     header_class = Cifti2Header
@@ -1454,6 +1458,9 @@ def to_file_map(self, file_map=None):
         self.update_headers()
         header = self._nifti_header
         extension = Cifti2Extension(content=self.header.to_xml())
+        header.extensions = Nifti1Extensions(
+            ext for ext in header.extensions if not isinstance(ext, Cifti2Extension)
+        )
         header.extensions.append(extension)
         if self._dataobj.shape != self.header.matrix.get_data_shape():
             raise ValueError(
diff --git a/nibabel/cifti2/tests/test_cifti2.py b/nibabel/cifti2/tests/test_cifti2.py
index 0d3d550a66..ea571065de 100644
--- a/nibabel/cifti2/tests/test_cifti2.py
+++ b/nibabel/cifti2/tests/test_cifti2.py
@@ -12,6 +12,7 @@
 import pytest
 
 from nibabel.tests.test_dataobj_images import TestDataobjAPI as _TDA
+from nibabel.tests.test_image_api import SerializeMixin
 
 
 def compare_xml_leaf(str1, str2):
@@ -406,7 +407,7 @@ def test_underscoring():
         assert ci.cifti2._underscore(camel) == underscored
 
 
-class TestCifti2ImageAPI(_TDA):
+class TestCifti2ImageAPI(_TDA, SerializeMixin):
     """ Basic validation for Cifti2Image instances
     """
     # A callable returning an image from ``image_maker(data, header)``