diff --git a/onedal/svm/__init__.py b/onedal/svm/__init__.py
index 6bcf140a4a..edff432c69 100644
--- a/onedal/svm/__init__.py
+++ b/onedal/svm/__init__.py
@@ -14,6 +14,6 @@
 # limitations under the License.
 # ==============================================================================
 
-from .svm import SVC, SVR, NuSVC, NuSVR, SVMtype
+from .svm import SVC, SVR, NuSVC, NuSVR
 
-__all__ = ["SVC", "SVR", "NuSVC", "NuSVR", "SVMtype"]
+__all__ = ["SVC", "SVR", "NuSVC", "NuSVR"]
diff --git a/onedal/svm/svm.py b/onedal/svm/svm.py
index f4184a40ac..f6732edda0 100644
--- a/onedal/svm/svm.py
+++ b/onedal/svm/svm.py
@@ -15,13 +15,13 @@
 # ==============================================================================
 
 from abc import ABCMeta, abstractmethod
-from enum import Enum
 
 import numpy as np
 from scipy import sparse as sp
 
 from onedal import _backend
 
+from ..common._base import BaseEstimator
 from ..common._estimator_checks import _check_is_fitted
 from ..common._mixin import ClassifierMixin, RegressorMixin
 from ..common._policy import _get_policy
@@ -35,14 +35,7 @@
 )
 
 
-class SVMtype(Enum):
-    c_svc = 0
-    epsilon_svr = 1
-    nu_svc = 2
-    nu_svr = 3
-
-
-class BaseSVM(metaclass=ABCMeta):
+class BaseSVM(BaseEstimator, metaclass=ABCMeta):
     @abstractmethod
     def __init__(
         self,
@@ -63,8 +56,6 @@ def __init__(
         decision_function_shape,
         break_ties,
         algorithm,
-        svm_type=None,
-        **kwargs,
     ):
         self.C = C
         self.nu = nu
@@ -82,21 +73,20 @@ def __init__(
         self.decision_function_shape = decision_function_shape
         self.break_ties = break_ties
         self.algorithm = algorithm
-        self.svm_type = svm_type
 
     def _validate_targets(self, y, dtype):
         self.class_weight_ = None
         self.classes_ = None
         return _column_or_1d(y, warn=True).astype(dtype, copy=False)
 
-    def _get_onedal_params(self, data):
+    def _get_onedal_params(self, dtype):
         max_iter = 10000 if self.max_iter == -1 else self.max_iter
         # TODO: remove this workaround
         # when oneDAL SVM starts support of 'n_iterations' result
         self.n_iter_ = 1 if max_iter < 1 else max_iter
         class_count = 0 if self.classes_ is None else len(self.classes_)
         return {
-            "fptype": data.dtype,
+            "fptype": dtype,
             "method": self.algorithm,
             "kernel": self.kernel,
             "c": self.C,
@@ -129,6 +119,7 @@ def _fit(self, X, y, sample_weight, module, queue):
             force_all_finite=True,
             accept_sparse="csr",
         )
+        # hard work remains on moving validate targets away from onedal
         y = self._validate_targets(y, X.dtype)
         if sample_weight is not None and len(sample_weight) > 0:
             sample_weight = _check_array(
@@ -154,29 +145,12 @@ def _fit(self, X, y, sample_weight, module, queue):
             self._scale_, self._sigma_ = 1.0, 1.0
             self.coef0 = 0.0
         else:
-            if isinstance(self.gamma, str):
-                if self.gamma == "scale":
-                    if sp.issparse(X):
-                        # var = E[X^2] - E[X]^2
-                        X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2
-                    else:
-                        X_sc = X.var()
-                    _gamma = 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0
-                elif self.gamma == "auto":
-                    _gamma = 1.0 / X.shape[1]
-                else:
-                    raise ValueError(
-                        "When 'gamma' is a string, it should be either 'scale' or "
-                        "'auto'. Got '{}' instead.".format(self.gamma)
-                    )
-            else:
-                _gamma = self.gamma
-            self._scale_, self._sigma_ = _gamma, np.sqrt(0.5 / _gamma)
+            self._scale_, self._sigma_ = self.gamma, np.sqrt(0.5 / self.gamma)
 
         policy = _get_policy(queue, *data)
-        X = _convert_to_supported(policy, X)
-        params = self._get_onedal_params(X)
-        result = module.train(policy, params, *to_table(*data))
+        data_t = to_table(*_convert_to_supported(policy, *data))
+        params = self._get_onedal_params(data_t[0].dtype)
+        result = module.train(policy, params, *data_t)
 
         if self._sparse:
             self.dual_coef_ = sp.csr_matrix(from_table(result.coeffs).T)
@@ -190,6 +164,7 @@ def _fit(self, X, y, sample_weight, module, queue):
         self.n_features_in_ = X.shape[1]
         self.shape_fit_ = X.shape
 
+        # _n_support not used in this object, will be moved to sklearnex
         if getattr(self, "classes_", None) is not None:
             indices = y.take(self.support_, axis=0)
             self._n_support = np.array(
@@ -206,128 +181,37 @@ def _create_model(self, module):
         m.support_vectors = to_table(self.support_vectors_)
         m.coeffs = to_table(self.dual_coef_.T)
         m.biases = to_table(self.intercept_)
-
-        if self.svm_type is SVMtype.c_svc or self.svm_type is SVMtype.nu_svc:
-            m.first_class_response, m.second_class_response = 0, 1
         return m
 
-    def _predict(self, X, module, queue):
+    def _infer(self, X, module, queue):
         _check_is_fitted(self)
-        if self.break_ties and self.decision_function_shape == "ovo":
-            raise ValueError(
-                "break_ties must be False when " "decision_function_shape is 'ovo'"
-            )
-
-        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
-            sv = self.support_vectors_
-            if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
-                raise ValueError(
-                    "The internal representation "
-                    f"of {self.__class__.__name__} was altered"
-                )
-
-        if (
-            self.break_ties
-            and self.decision_function_shape == "ovr"
-            and len(self.classes_) > 2
-        ):
-            y = np.argmax(self.decision_function(X), axis=1)
-        else:
-            X = _check_array(
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=True,
-                accept_sparse="csr",
-            )
-            _check_n_features(self, X, False)
 
-            if self._sparse and not sp.isspmatrix(X):
+        if self._sparse:
+            if not sp.isspmatrix(X):
                 X = sp.csr_matrix(X)
-            if self._sparse:
-                X.sort_indices()
-
-            if sp.issparse(X) and not self._sparse and not callable(self.kernel):
-                raise ValueError(
-                    "cannot use sparse input in %r trained on dense data"
-                    % type(self).__name__
-                )
-
-            policy = _get_policy(queue, X)
-            X = _convert_to_supported(policy, X)
-            params = self._get_onedal_params(X)
-
-            if hasattr(self, "_onedal_model"):
-                model = self._onedal_model
             else:
-                model = self._create_model(module)
-            result = module.infer(policy, params, model, to_table(X))
-            y = from_table(result.responses)
-        return y
-
-    def _ovr_decision_function(self, predictions, confidences, n_classes):
-        n_samples = predictions.shape[0]
-        votes = np.zeros((n_samples, n_classes))
-        sum_of_confidences = np.zeros((n_samples, n_classes))
-
-        k = 0
-        for i in range(n_classes):
-            for j in range(i + 1, n_classes):
-                sum_of_confidences[:, i] -= confidences[:, k]
-                sum_of_confidences[:, j] += confidences[:, k]
-                votes[predictions[:, k] == 0, i] += 1
-                votes[predictions[:, k] == 1, j] += 1
-                k += 1
-
-        transformed_confidences = sum_of_confidences / (
-            3 * (np.abs(sum_of_confidences) + 1)
-        )
-        return votes + transformed_confidences
-
-    def _decision_function(self, X, module, queue):
-        _check_is_fitted(self)
-        X = _check_array(
-            X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse="csr"
-        )
-        _check_n_features(self, X, False)
-
-        if self._sparse and not sp.isspmatrix(X):
-            X = sp.csr_matrix(X)
-        if self._sparse:
-            X.sort_indices()
-
-        if sp.issparse(X) and not self._sparse and not callable(self.kernel):
+                X.sort_indices()
+        elif sp.issparse(X) and not callable(self.kernel):
             raise ValueError(
                 "cannot use sparse input in %r trained on dense data"
                 % type(self).__name__
             )
 
-        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
-            sv = self.support_vectors_
-            if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
-                raise ValueError(
-                    "The internal representation "
-                    f"of {self.__class__.__name__} was altered"
-                )
-
         policy = _get_policy(queue, X)
-        X = _convert_to_supported(policy, X)
-        params = self._get_onedal_params(X)
+        X = to_table(_convert_to_supported(policy, X))
+        params = self._get_onedal_params(X.dtype)
 
         if hasattr(self, "_onedal_model"):
             model = self._onedal_model
         else:
             model = self._create_model(module)
-        result = module.infer(policy, params, model, to_table(X))
-        decision_function = from_table(result.decision_function)
+        return module.infer(policy, params, model, X)
 
-        if len(self.classes_) == 2:
-            decision_function = decision_function.ravel()
+    def _predict(self, X, module, queue):
+        return from_table(self._infer(X, module, queue).responses)
 
-        if self.decision_function_shape == "ovr" and len(self.classes_) > 2:
-            decision_function = self._ovr_decision_function(
-                decision_function < 0, -decision_function, len(self.classes_)
-            )
-        return decision_function
+    def _decision_function(self, X, module, queue):
+        return from_table(self._infer(X, module, queue).decision_function)
 
 
 class SVR(RegressorMixin, BaseSVM):
@@ -350,7 +234,6 @@ def __init__(
         max_iter=-1,
         tau=1e-12,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=C,
@@ -370,14 +253,12 @@ def __init__(
             break_ties=False,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.epsilon_svr
 
     def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.regression, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.regression, queue)
-        return y.ravel()
+        return super()._predict(X, _backend.svm.regression, queue)
 
 
 class SVC(ClassifierMixin, BaseSVM):
@@ -402,7 +283,6 @@ def __init__(
         decision_function_shape="ovr",
         break_ties=False,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=C,
@@ -422,7 +302,11 @@ def __init__(
             break_ties=break_ties,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.c_svc
+
+    def _create_model(self, module):
+        m = super()._create_model(module)
+        m.first_class_response, m.second_class_response = 0, 1
+        return m
 
     def _validate_targets(self, y, dtype):
         y, self.class_weight_, self.classes_ = _validate_targets(
@@ -434,10 +318,7 @@ def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.classification, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.classification, queue)
-        if len(self.classes_) == 2:
-            y = y.ravel()
-        return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel()
+        return super()._predict(X, _backend.svm.classification, queue)
 
     def decision_function(self, X, queue=None):
         return super()._decision_function(X, _backend.svm.classification, queue)
@@ -463,7 +344,6 @@ def __init__(
         max_iter=-1,
         tau=1e-12,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=C,
@@ -483,14 +363,12 @@ def __init__(
             break_ties=False,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.nu_svr
 
     def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.nu_regression, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.nu_regression, queue)
-        return y.ravel()
+        return super()._predict(X, _backend.svm.nu_regression, queue)
 
 
 class NuSVC(ClassifierMixin, BaseSVM):
@@ -515,7 +393,6 @@ def __init__(
         decision_function_shape="ovr",
         break_ties=False,
         algorithm="thunder",
-        **kwargs,
     ):
         super().__init__(
             C=1.0,
@@ -535,7 +412,11 @@ def __init__(
             break_ties=break_ties,
             algorithm=algorithm,
         )
-        self.svm_type = SVMtype.nu_svc
+
+    def _create_model(self, module):
+        m = super()._create_model(module)
+        m.first_class_response, m.second_class_response = 0, 1
+        return m
 
     def _validate_targets(self, y, dtype):
         y, self.class_weight_, self.classes_ = _validate_targets(
@@ -547,10 +428,7 @@ def fit(self, X, y, sample_weight=None, queue=None):
         return super()._fit(X, y, sample_weight, _backend.svm.nu_classification, queue)
 
     def predict(self, X, queue=None):
-        y = super()._predict(X, _backend.svm.nu_classification, queue)
-        if len(self.classes_) == 2:
-            y = y.ravel()
-        return self.classes_.take(np.asarray(y, dtype=np.intp)).ravel()
+        return super()._predict(X, _backend.svm.nu_classification, queue)
 
     def decision_function(self, X, queue=None):
         return super()._decision_function(X, _backend.svm.nu_classification, queue)
diff --git a/onedal/svm/tests/test_csr_svm.py b/onedal/svm/tests/test_csr_svm.py
index e4a05a030e..04bcd77d75 100644
--- a/onedal/svm/tests/test_csr_svm.py
+++ b/onedal/svm/tests/test_csr_svm.py
@@ -61,12 +61,12 @@ def check_svm_model_equal(
 
 
 def _test_simple_dataset(queue, kernel):
-    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
-    sparse_X = sp.lil_matrix(X)
-    Y = [1, 1, 1, 2, 2, 2]
+    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float64)
+    sparse_X = sp.csr_matrix(X)
+    Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float64)
 
-    X2 = np.array([[-1, -1], [2, 2], [3, 2]])
-    sparse_X2 = sp.dok_matrix(X2)
+    X2 = np.array([[-1, -1], [2, 2], [3, 2]], dtype=np.float64)
+    sparse_X2 = sp.csr_matrix(X2)
 
     dataset = sparse_X, Y, sparse_X2
     clf0 = SVC(kernel=kernel, gamma=1)
@@ -93,80 +93,6 @@ def test_simple_dataset(queue, kernel):
     _test_simple_dataset(queue, kernel)
 
 
-def _test_binary_dataset(queue, kernel):
-    X, y = make_classification(n_samples=80, n_features=20, n_classes=2, random_state=0)
-    sparse_X = sp.csr_matrix(X)
-
-    dataset = sparse_X, y, sparse_X
-    clf0 = SVC(kernel=kernel)
-    clf1 = SVC(kernel=kernel)
-    check_svm_model_equal(queue, clf0, clf1, *dataset)
-
-
-@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
-@pytest.mark.parametrize(
-    "queue",
-    get_queues("cpu")
-    + [
-        pytest.param(
-            get_queues("gpu"),
-            marks=pytest.mark.xfail(
-                reason="raises UnknownError for linear and rbf, "
-                "Unimplemented error with inconsistent error message "
-                "for poly and sigmoid"
-            ),
-        )
-    ],
-)
-@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
-def test_binary_dataset(queue, kernel):
-    _test_binary_dataset(queue, kernel)
-
-
-def _test_iris(queue, kernel):
-    iris = datasets.load_iris()
-    rng = np.random.RandomState(0)
-    perm = rng.permutation(iris.target.size)
-    iris.data = iris.data[perm]
-    iris.target = iris.target[perm]
-    sparse_iris_data = sp.csr_matrix(iris.data)
-
-    dataset = sparse_iris_data, iris.target, sparse_iris_data
-
-    clf0 = SVC(kernel=kernel)
-    clf1 = SVC(kernel=kernel)
-    check_svm_model_equal(queue, clf0, clf1, *dataset, decimal=2)
-
-
-@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
-@pytest.mark.parametrize("queue", get_queues())
-@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
-def test_iris(queue, kernel):
-    if kernel == "rbf":
-        pytest.skip("RBF CSR SVM test failing in 2025.0.")
-    _test_iris(queue, kernel)
-
-
-def _test_diabetes(queue, kernel):
-    diabetes = datasets.load_diabetes()
-
-    sparse_diabetes_data = sp.csr_matrix(diabetes.data)
-    dataset = sparse_diabetes_data, diabetes.target, sparse_diabetes_data
-
-    clf0 = SVR(kernel=kernel, C=0.1)
-    clf1 = SVR(kernel=kernel, C=0.1)
-    check_svm_model_equal(queue, clf0, clf1, *dataset)
-
-
-@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
-@pytest.mark.parametrize("queue", get_queues())
-@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
-def test_diabetes(queue, kernel):
-    if kernel == "sigmoid":
-        pytest.skip("Sparse sigmoid kernel function is buggy.")
-    _test_diabetes(queue, kernel)
-
-
 @pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
 @pytest.mark.xfail(reason="Failed test. Need investigate")
 @pytest.mark.parametrize("queue", get_queues())
diff --git a/onedal/svm/tests/test_nusvc.py b/onedal/svm/tests/test_nusvc.py
index c8bf99a9d3..692a9d78c4 100644
--- a/onedal/svm/tests/test_nusvc.py
+++ b/onedal/svm/tests/test_nusvc.py
@@ -32,7 +32,7 @@
 
 def _test_libsvm_parameters(queue, array_constr, dtype):
     X = array_constr([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=dtype)
-    y = array_constr([1, 1, 1, 2, 2, 2], dtype=dtype)
+    y = array_constr([0, 0, 0, 1, 1, 1], dtype=dtype)
 
     clf = NuSVC(kernel="linear").fit(X, y, queue=queue)
     assert_array_almost_equal(
@@ -41,7 +41,7 @@ def _test_libsvm_parameters(queue, array_constr, dtype):
     assert_array_equal(clf.support_, [0, 1, 3, 4])
     assert_array_equal(clf.support_vectors_, X[clf.support_])
     assert_array_equal(clf.intercept_, [0.0])
-    assert_array_equal(clf.predict(X, queue=queue), y)
+    assert_array_equal(clf.predict(X, queue=queue).ravel(), y)
 
 
 @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented")
@@ -55,12 +55,12 @@ def test_libsvm_parameters(queue, array_constr, dtype):
 @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented")
 @pytest.mark.parametrize("queue", get_queues())
 def test_class_weight(queue):
-    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
-    y = np.array([1, 1, 1, 2, 2, 2])
+    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float64)
+    y = np.array([0, 0, 0, 1, 1, 1], dtype=np.float64)
 
-    clf = NuSVC(class_weight={1: 0.1})
+    clf = NuSVC(class_weight={0: 0.1})
     clf.fit(X, y, queue=queue)
-    assert_array_almost_equal(clf.predict(X, queue=queue), [2] * 6)
+    assert_array_almost_equal(clf.predict(X, queue=queue).ravel(), [1] * 6)
 
 
 @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented")
@@ -77,15 +77,15 @@ def test_sample_weight(queue):
 @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented")
 @pytest.mark.parametrize("queue", get_queues())
 def test_decision_function(queue):
-    X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
-    Y = [1, 1, 1, 2, 2, 2]
+    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float32)
+    Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float32)
 
     clf = NuSVC(kernel="rbf", gamma=1, decision_function_shape="ovo")
     clf.fit(X, Y, queue=queue)
 
     rbfs = rbf_kernel(X, clf.support_vectors_, gamma=clf.gamma)
     dec = np.dot(rbfs, clf.dual_coef_.T) + clf.intercept_
-    assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue))
+    assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue).ravel())
 
 
 @pass_if_not_implemented_for_gpu(reason="nusvc is not implemented")
diff --git a/onedal/svm/tests/test_nusvr.py b/onedal/svm/tests/test_nusvr.py
index 1bec991961..bfd40c5767 100644
--- a/onedal/svm/tests/test_nusvr.py
+++ b/onedal/svm/tests/test_nusvr.py
@@ -196,7 +196,7 @@ def test_synth_poly_compare_with_sklearn(queue, params):
 def test_pickle(queue):
     diabetes = datasets.load_diabetes()
 
-    clf = NuSVR(kernel="rbf", C=10.0)
+    clf = NuSVR(kernel="linear", C=10.0)
     clf.fit(diabetes.data, diabetes.target, queue=queue)
     expected = clf.predict(diabetes.data, queue=queue)
 
diff --git a/onedal/svm/tests/test_svc.py b/onedal/svm/tests/test_svc.py
index 9f7eaa4810..8e148383fa 100644
--- a/onedal/svm/tests/test_svc.py
+++ b/onedal/svm/tests/test_svc.py
@@ -32,14 +32,14 @@
 
 def _test_libsvm_parameters(queue, array_constr, dtype):
     X = array_constr([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=dtype)
-    y = array_constr([1, 1, 1, 2, 2, 2], dtype=dtype)
+    y = array_constr([0, 0, 0, 1, 1, 1], dtype=dtype)
 
     clf = SVC(kernel="linear").fit(X, y, queue=queue)
     assert_array_equal(clf.dual_coef_, [[-0.25, 0.25]])
     assert_array_equal(clf.support_, [1, 3])
     assert_array_equal(clf.support_vectors_, (X[1], X[3]))
     assert_array_equal(clf.intercept_, [0.0])
-    assert_array_equal(clf.predict(X), y)
+    assert_array_equal(clf.predict(X).ravel(), y)
 
 
 @pytest.mark.parametrize("queue", get_queues())
@@ -65,12 +65,12 @@ def test_libsvm_parameters(queue, array_constr, dtype):
     ],
 )
 def test_class_weight(queue):
-    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
-    y = np.array([1, 1, 1, 2, 2, 2])
+    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float64)
+    y = np.array([0, 0, 0, 1, 1, 1], dtype=np.float64)
 
-    clf = SVC(class_weight={1: 0.1})
+    clf = SVC(class_weight={0: 0.1})
     clf.fit(X, y, queue=queue)
-    assert_array_almost_equal(clf.predict(X, queue=queue), [2] * 6)
+    assert_array_almost_equal(clf.predict(X, queue=queue).ravel(), [1] * 6)
 
 
 @pytest.mark.parametrize("queue", get_queues())
@@ -95,7 +95,7 @@ def test_decision_function(queue):
 
     rbfs = rbf_kernel(X, clf.support_vectors_, gamma=clf.gamma)
     dec = np.dot(rbfs, clf.dual_coef_.T) + clf.intercept_
-    assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue))
+    assert_array_almost_equal(dec.ravel(), clf.decision_function(X, queue=queue).ravel())
 
 
 @pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented")
@@ -160,9 +160,9 @@ def test_svc_sigmoid(queue, dtype):
         [[-1, 2], [0, 0], [2, -1], [+1, +1], [+1, +2], [+2, +1]], dtype=dtype
     )
     X_test = np.array([[0, 2], [0.5, 0.5], [0.3, 0.1], [2, 0], [-1, -1]], dtype=dtype)
-    y_train = np.array([1, 1, 1, 2, 2, 2], dtype=dtype)
+    y_train = np.array([0, 0, 0, 1, 1, 1], dtype=dtype)
     svc = SVC(kernel="sigmoid").fit(X_train, y_train, queue=queue)
 
     assert_array_equal(svc.dual_coef_, [[-1, -1, -1, 1, 1, 1]])
     assert_array_equal(svc.support_, [0, 1, 2, 3, 4, 5])
-    assert_array_equal(svc.predict(X_test, queue=queue), [2, 2, 1, 2, 1])
+    assert_array_equal(svc.predict(X_test, queue=queue).ravel(), [1, 1, 0, 1, 0])
diff --git a/onedal/svm/tests/test_svr.py b/onedal/svm/tests/test_svr.py
index a9000ff5f7..82ce86a78c 100644
--- a/onedal/svm/tests/test_svr.py
+++ b/onedal/svm/tests/test_svr.py
@@ -206,22 +206,24 @@ def test_synth_poly_compare_with_sklearn(queue, params):
 def test_sided_sample_weight(queue):
     clf = SVR(C=1e-2, kernel="linear")
 
-    X = [[-2, 0], [-1, -1], [0, -2], [0, 2], [1, 1], [2, 0]]
-    Y = [1, 1, 1, 2, 2, 2]
+    X = np.array([[-2, 0], [-1, -1], [0, -2], [0, 2], [1, 1], [2, 0]], dtype=np.float64)
+    Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float64)
 
-    sample_weight = [10.0, 0.1, 0.1, 0.1, 0.1, 10]
+    X_pred = np.array([[-1.0, 1.0]], dtype=np.float64)
+
+    sample_weight = np.array([10.0, 0.1, 0.1, 0.1, 0.1, 10], dtype=np.float64)
     clf.fit(X, Y, sample_weight=sample_weight, queue=queue)
-    y_pred = clf.predict([[-1.0, 1.0]], queue=queue)
+    y_pred = clf.predict(X_pred, queue=queue)
     assert y_pred < 1.5
 
-    sample_weight = [1.0, 0.1, 10.0, 10.0, 0.1, 0.1]
+    sample_weight = np.array([1.0, 0.1, 10.0, 10.0, 0.1, 0.1], dtype=np.float64)
     clf.fit(X, Y, sample_weight=sample_weight, queue=queue)
-    y_pred = clf.predict([[-1.0, 1.0]], queue=queue)
+    y_pred = clf.predict(X_pred, queue=queue)
     assert y_pred > 1.5
 
-    sample_weight = [1] * 6
+    sample_weight = np.array([1] * 6, dtype=np.float64)
     clf.fit(X, Y, sample_weight=sample_weight, queue=queue)
-    y_pred = clf.predict([[-1.0, 1.0]], queue=queue)
+    y_pred = clf.predict(X_pred, queue=queue)
     assert y_pred == pytest.approx(1.5)
 
 
@@ -229,7 +231,7 @@ def test_sided_sample_weight(queue):
 @pytest.mark.parametrize("queue", get_queues())
 def test_pickle(queue):
     diabetes = datasets.load_diabetes()
-    clf = SVR(kernel="rbf", C=10.0)
+    clf = SVR(kernel="linear", C=10.0)
     clf.fit(diabetes.data, diabetes.target, queue=queue)
     expected = clf.predict(diabetes.data, queue=queue)
 
diff --git a/sklearnex/svm/_common.py b/sklearnex/svm/_common.py
index 4b481314ae..da7ab39927 100644
--- a/sklearnex/svm/_common.py
+++ b/sklearnex/svm/_common.py
@@ -20,16 +20,25 @@
 
 import numpy as np
 from scipy import sparse as sp
-from sklearn.base import BaseEstimator, ClassifierMixin
+from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
 from sklearn.calibration import CalibratedClassifierCV
-from sklearn.metrics import r2_score
+from sklearn.exceptions import NotFittedError
+from sklearn.metrics import accuracy_score, r2_score
 from sklearn.preprocessing import LabelEncoder
+from sklearn.svm._base import BaseLibSVM as _sklearn_BaseLibSVM
+from sklearn.svm._base import BaseSVC as _sklearn_BaseSVC
+from sklearn.utils.validation import check_array, check_is_fitted
 
 from daal4py.sklearn._utils import sklearn_check_version
 from onedal.utils import _check_array, _check_X_y, _column_or_1d
 
 from .._config import config_context, get_config
+from .._device_offload import dispatch, wrap_output_data
 from .._utils import PatchingConditionsChain
+from ..utils._array_api import get_namespace
+
+if sklearn_check_version("1.0"):
+    from sklearn.utils.metaestimators import available_if
 
 if sklearn_check_version("1.6"):
     from sklearn.utils.validation import validate_data
@@ -37,7 +46,9 @@
     validate_data = BaseEstimator._validate_data
 
 
-class BaseSVM(BaseEstimator, ABC):
+class BaseSVM(object):
+
+    _onedal_factory = None
 
     @property
     def _dual_coef_(self):
@@ -94,7 +105,7 @@ def _onedal_cpu_supported(self, method_name, *data):
             return patching_status
         inference_methods = (
             ["predict", "score"]
-            if class_name.endswith("R")
+            if isinstance(self, RegressorMixin)
             else ["predict", "predict_proba", "decision_function", "score"]
         )
         if method_name in inference_methods:
@@ -238,8 +249,168 @@ def _get_sample_weight(self, X, y, sample_weight):
 
         return sample_weight
 
+    def _onedal_predict(self, X, queue=None, xp=None):
+        if xp is None:
+            xp, _ = get_namespace(X)
+
+        if sklearn_check_version("1.0"):
+            X = validate_data(
+                self,
+                X,
+                dtype=[xp.float64, xp.float32],
+                accept_sparse="csr",
+                reset=False,
+            )
+        else:
+            X = check_array(
+                X,
+                dtype=[xp.float64, xp.float32],
+                accept_sparse="csr",
+            )
+
+        return self._onedal_estimator.predict(X, queue=queue)
+
 
 class BaseSVC(BaseSVM):
+
+    @wrap_output_data
+    def predict(self, X):
+        check_is_fitted(self)
+        return dispatch(
+            self,
+            "predict",
+            {
+                "onedal": self.__class__._onedal_predict,
+                "sklearn": _sklearn_BaseSVC.predict,
+            },
+            X,
+        )
+
+    @wrap_output_data
+    def score(self, X, y, sample_weight=None):
+        check_is_fitted(self)
+        return dispatch(
+            self,
+            "score",
+            {
+                "onedal": self.__class__._onedal_score,
+                "sklearn": _sklearn_BaseSVC.score,
+            },
+            X,
+            y,
+            sample_weight=sample_weight,
+        )
+
+    @wrap_output_data
+    def decision_function(self, X):
+        check_is_fitted(self)
+        return dispatch(
+            self,
+            "decision_function",
+            {
+                "onedal": self.__class__._onedal_decision_function,
+                "sklearn": _sklearn_BaseSVC.decision_function,
+            },
+            X,
+        )
+
+    if sklearn_check_version("1.0"):
+
+        @available_if(_sklearn_BaseSVC._check_proba)
+        def predict_proba(self, X):
+            """
+            Compute probabilities of possible outcomes for samples in X.
+
+            The model need to have probability information computed at training
+            time: fit with attribute `probability` set to True.
+
+            Parameters
+            ----------
+            X : array-like of shape (n_samples, n_features)
+                For kernel="precomputed", the expected shape of X is
+                (n_samples_test, n_samples_train).
+
+            Returns
+            -------
+            T : ndarray of shape (n_samples, n_classes)
+                Returns the probability of the sample for each class in
+                the model. The columns correspond to the classes in sorted
+                order, as they appear in the attribute :term:`classes_`.
+
+            Notes
+            -----
+            The probability model is created using cross validation, so
+            the results can be slightly different than those obtained by
+            predict. Also, it will produce meaningless results on very small
+            datasets.
+            """
+            check_is_fitted(self)
+            return self._predict_proba(X)
+
+        @available_if(_sklearn_BaseSVC._check_proba)
+        def predict_log_proba(self, X):
+            """Compute log probabilities of possible outcomes for samples in X.
+
+            The model need to have probability information computed at training
+            time: fit with attribute `probability` set to True.
+
+            Parameters
+            ----------
+            X : array-like of shape (n_samples, n_features) or \
+                    (n_samples_test, n_samples_train)
+                For kernel="precomputed", the expected shape of X is
+                (n_samples_test, n_samples_train).
+
+            Returns
+            -------
+            T : ndarray of shape (n_samples, n_classes)
+                Returns the log-probabilities of the sample for each class in
+                the model. The columns correspond to the classes in sorted
+                order, as they appear in the attribute :term:`classes_`.
+
+            Notes
+            -----
+            The probability model is created using cross validation, so
+            the results can be slightly different than those obtained by
+            predict. Also, it will produce meaningless results on very small
+            datasets.
+            """
+            xp, _ = get_namespace(X)
+
+            return xp.log(self.predict_proba(X))
+
+    else:
+
+        @property
+        def predict_proba(self):
+            self._check_proba()
+            check_is_fitted(self)
+            return self._predict_proba
+
+        def _predict_log_proba(self, X):
+            xp, _ = get_namespace(X)
+            return xp.log(self.predict_proba(X))
+
+        predict_proba.__doc__ = _sklearn_BaseSVC.predict_proba.__doc__
+
+    @wrap_output_data
+    def _predict_proba(self, X):
+        sklearn_pred_proba = (
+            _sklearn_BaseSVC.predict_proba
+            if sklearn_check_version("1.0")
+            else _sklearn_BaseSVC._predict_proba
+        )
+
+        return dispatch(
+            self,
+            "predict_proba",
+            {
+                "onedal": self.__class__._onedal_predict_proba,
+                "sklearn": sklearn_pred_proba,
+            },
+            X,
+        )
+
     def _compute_balanced_class_weight(self, y):
         y_ = _column_or_1d(y)
         classes, _ = np.unique(y_, return_inverse=True)
@@ -289,7 +460,7 @@ def _save_attributes(self):
         self.dual_coef_ = self._onedal_estimator.dual_coef_
         self.shape_fit_ = self._onedal_estimator.class_weight_
         self.classes_ = self._onedal_estimator.classes_
-        if isinstance(self, ClassifierMixin) or not sklearn_check_version("1.2"):
+        if not sklearn_check_version("1.2"):
             self.class_weight_ = self._onedal_estimator.class_weight_
         self.support_ = self._onedal_estimator.support_
 
@@ -311,8 +482,145 @@ def _save_attributes(self):
             length = int(len(self.classes_) * (len(self.classes_) - 1) / 2)
             self.n_iter_ = np.full((length,), self._onedal_estimator.n_iter_)
 
+    def _onedal_predict(self, X, queue=None):
+        sv = self.support_vectors_
+        if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
+            raise ValueError(
+                "The internal representation " f"of {self.__class__.__name__} was altered"
+            )
+
+        xp, _ = get_namespace(X)
+
+        if self.break_ties and self.decision_function_shape == "ovo":
+            raise ValueError(
+                "break_ties must be False when " "decision_function_shape is 'ovo'"
+            )
+
+        if (
+            self.break_ties
+            and self.decision_function_shape == "ovr"
+            and len(self.classes_) > 2
+        ):
+            res = xp.argmax(self._onedal_decision_function(X, queue=queue), axis=1)
+        else:
+            res = super()._onedal_predict(X, queue=queue, xp=xp)
+
+        # the extensive reshaping here comes from the previous implementation, and
+        # should be sorted out, as this is inefficient and likely can be reduced
+        res = xp.asarray(res, dtype=xp.int32)
+        if len(self.classes_) == 2:
+            res = xp.reshape(res, (-1,))
+
+        return xp.reshape(xp.take(xp.asarray(self.classes_), res), (-1,))
+
+    def _onedal_ovr_decision_function(self, predictions, confidences, n_classes):
+        # This function is legacy from the original implementation and needs
+        # to be refactored.
+        xp, _ = get_namespace(predictions)
+        n_samples = predictions.shape[0]
+        votes = xp.zeros((n_samples, n_classes))
+        sum_of_confidences = xp.zeros((n_samples, n_classes))
+
+        k = 0
+        for i in range(n_classes):
+            for j in range(i + 1, n_classes):
+                sum_of_confidences[:, i] -= confidences[:, k]
+                sum_of_confidences[:, j] += confidences[:, k]
+                votes[predictions[:, k] == 0, i] += 1
+                votes[predictions[:, k] == 1, j] += 1
+                k += 1
+
+        transformed_confidences = sum_of_confidences / (
+            3 * (xp.abs(sum_of_confidences) + 1)
+        )
+        return votes + transformed_confidences
+
+    def _onedal_decision_function(self, X, queue=None):
+        sv = self.support_vectors_
+        if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
+            raise ValueError(
+                "The internal representation " f"of {self.__class__.__name__} was altered"
+            )
+        xp, _ = get_namespace(X)
+        if sklearn_check_version("1.0"):
+            X = validate_data(
+                self,
+                X,
+                dtype=[xp.float64, xp.float32],
+                accept_sparse="csr",
+                reset=False,
+            )
+        else:
+            X = check_array(
+                X,
+                dtype=[xp.float64, xp.float32],
+                accept_sparse="csr",
+            )
+
+        decision_function = self._onedal_estimator.decision_function(X, queue=queue)
+
+        if len(self.classes_) == 2:
+            decision_function = decision_function.ravel()
+        elif len(self.classes_) > 2 and self.decision_function_shape == "ovr":
+            decision_function = self._onedal_ovr_decision_function(
+                decision_function < 0, -decision_function, len(self.classes_)
+            )
+
+        return xp.asarray(decision_function)
+
+    def _onedal_predict_proba(self, X, queue=None):
+        if getattr(self, "clf_prob", None) is None:
+            raise NotFittedError(
+                "predict_proba is not available when fitted with probability=False"
+            )
+        from .._config import config_context, get_config
+
+        # We use stock metaestimators below, so the only way
+        # to pass a queue is using config_context.
+        cfg = get_config()
+        cfg["target_offload"] = queue
+        with config_context(**cfg):
+            return self.clf_prob.predict_proba(X)
+
+    def _onedal_score(self, X, y, sample_weight=None, queue=None):
+        return accuracy_score(
+            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
+        )
+
+    predict.__doc__ = _sklearn_BaseSVC.predict.__doc__
+    decision_function.__doc__ = _sklearn_BaseSVC.decision_function.__doc__
+    score.__doc__ = _sklearn_BaseSVC.score.__doc__
+
 
 class BaseSVR(BaseSVM):
+    @wrap_output_data
+    def predict(self, X):
+        check_is_fitted(self)
+        return dispatch(
+            self,
+            "predict",
+            {
+                "onedal": self.__class__._onedal_predict,
+                "sklearn": _sklearn_BaseLibSVM.predict,
+            },
+            X,
+        )
+
+    @wrap_output_data
+    def score(self, X, y, sample_weight=None):
+        check_is_fitted(self)
+        return dispatch(
+            self,
+            "score",
+            {
+                "onedal": self.__class__._onedal_score,
+                "sklearn": RegressorMixin.score,
+            },
+            X,
+            y,
+            sample_weight=sample_weight,
+        )
+
     def _save_attributes(self):
         self.support_vectors_ = self._onedal_estimator.support_vectors_
         self.n_features_in_ = self._onedal_estimator.n_features_in_
@@ -333,7 +641,15 @@ def _save_attributes(self):
 
         self._dualcoef_ = self.dual_coef_
 
+    def _onedal_predict(self, X, queue=None):
+        xp, _ = get_namespace(X)
+        res = super()._onedal_predict(X, queue)
+        return xp.reshape(res, (-1,))
+
     def _onedal_score(self, X, y, sample_weight=None, queue=None):
         return r2_score(
             y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
         )
+
+    predict.__doc__ = _sklearn_BaseLibSVM.predict.__doc__
+    score.__doc__ = RegressorMixin.score.__doc__
diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py
index 301d90ccc4..b53cb71f63 100644
--- a/sklearnex/svm/nusvc.py
+++ b/sklearnex/svm/nusvc.py
@@ -15,8 +15,6 @@
 # ==============================================================================
 
 import numpy as np
-from sklearn.exceptions import NotFittedError
-from sklearn.metrics import accuracy_score
 from sklearn.svm import NuSVC as _sklearn_NuSVC
 from sklearn.utils.validation import (
     _deprecate_positional_args,
@@ -26,27 +24,23 @@
 
 from daal4py.sklearn._n_jobs_support import control_n_jobs
 from daal4py.sklearn._utils import sklearn_check_version
+from onedal.svm import NuSVC as onedal_NuSVC
 
 from .._device_offload import dispatch, wrap_output_data
-from ..utils._array_api import get_namespace
 from ._common import BaseSVC
 
-if sklearn_check_version("1.0"):
-    from sklearn.utils.metaestimators import available_if
-
-from onedal.svm import NuSVC as onedal_NuSVC
-
 if sklearn_check_version("1.6"):
     from sklearn.utils.validation import validate_data
 else:
-    validate_data = BaseSVC._validate_data
+    validate_data = _sklearn_NuSVC._validate_data
 
 
 @control_n_jobs(
     decorated_methods=["fit", "predict", "_predict_proba", "decision_function", "score"]
 )
-class NuSVC(_sklearn_NuSVC, BaseSVC):
+class NuSVC(BaseSVC, _sklearn_NuSVC):
     __doc__ = _sklearn_NuSVC.__doc__
+    _onedal_factory = onedal_NuSVC
 
     if sklearn_check_version("1.2"):
         _parameter_constraints: dict = {**_sklearn_NuSVC._parameter_constraints}
@@ -117,146 +111,6 @@ def fit(self, X, y, sample_weight=None):
 
         return self
 
-    @wrap_output_data
-    def predict(self, X):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "predict",
-            {
-                "onedal": self.__class__._onedal_predict,
-                "sklearn": _sklearn_NuSVC.predict,
-            },
-            X,
-        )
-
-    @wrap_output_data
-    def score(self, X, y, sample_weight=None):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "score",
-            {
-                "onedal": self.__class__._onedal_score,
-                "sklearn": _sklearn_NuSVC.score,
-            },
-            X,
-            y,
-            sample_weight=sample_weight,
-        )
-
-    if sklearn_check_version("1.0"):
-
-        @available_if(_sklearn_NuSVC._check_proba)
-        def predict_proba(self, X):
-            """
-            Compute probabilities of possible outcomes for samples in X.
-
-            The model need to have probability information computed at training
-            time: fit with attribute `probability` set to True.
-
-            Parameters
-            ----------
-            X : array-like of shape (n_samples, n_features)
-                For kernel="precomputed", the expected shape of X is
-                (n_samples_test, n_samples_train).
-
-            Returns
-            -------
-            T : ndarray of shape (n_samples, n_classes)
-                Returns the probability of the sample for each class in
-                the model. The columns correspond to the classes in sorted
-                order, as they appear in the attribute :term:`classes_`.
-
-            Notes
-            -----
-            The probability model is created using cross validation, so
-            the results can be slightly different than those obtained by
-            predict. Also, it will produce meaningless results on very small
-            datasets.
-            """
-            check_is_fitted(self)
-            return self._predict_proba(X)
-
-        @available_if(_sklearn_NuSVC._check_proba)
-        def predict_log_proba(self, X):
-            """Compute log probabilities of possible outcomes for samples in X.
-
-            The model need to have probability information computed at training
-            time: fit with attribute `probability` set to True.
-
-            Parameters
-            ----------
-            X : array-like of shape (n_samples, n_features) or \
-                    (n_samples_test, n_samples_train)
-                For kernel="precomputed", the expected shape of X is
-                (n_samples_test, n_samples_train).
-
-            Returns
-            -------
-            T : ndarray of shape (n_samples, n_classes)
-                Returns the log-probabilities of the sample for each class in
-                the model. The columns correspond to the classes in sorted
-                order, as they appear in the attribute :term:`classes_`.
-
-            Notes
-            -----
-            The probability model is created using cross validation, so
-            the results can be slightly different than those obtained by
-            predict. Also, it will produce meaningless results on very small
-            datasets.
-            """
-            xp, _ = get_namespace(X)
-
-            return xp.log(self.predict_proba(X))
-
-    else:
-
-        @property
-        def predict_proba(self):
-            self._check_proba()
-            check_is_fitted(self)
-            return self._predict_proba
-
-        def _predict_log_proba(self, X):
-            xp, _ = get_namespace(X)
-            return xp.log(self.predict_proba(X))
-
-        predict_proba.__doc__ = _sklearn_NuSVC.predict_proba.__doc__
-
-    @wrap_output_data
-    def _predict_proba(self, X):
-        sklearn_pred_proba = (
-            _sklearn_NuSVC.predict_proba
-            if sklearn_check_version("1.0")
-            else _sklearn_NuSVC._predict_proba
-        )
-
-        return dispatch(
-            self,
-            "predict_proba",
-            {
-                "onedal": self.__class__._onedal_predict_proba,
-                "sklearn": sklearn_pred_proba,
-            },
-            X,
-        )
-
-    @wrap_output_data
-    def decision_function(self, X):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "decision_function",
-            {
-                "onedal": self.__class__._onedal_decision_function,
-                "sklearn": _sklearn_NuSVC.decision_function,
-            },
-            X,
-        )
-
-    decision_function.__doc__ = _sklearn_NuSVC.decision_function.__doc__
-
     def _get_sample_weight(self, X, y, sample_weight=None):
         sample_weight = super()._get_sample_weight(X, y, sample_weight)
         if sample_weight is None:
@@ -292,7 +146,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
             "decision_function_shape": self.decision_function_shape,
         }
 
-        self._onedal_estimator = onedal_NuSVC(**onedal_params)
+        self._onedal_estimator = self._onedal_factory(**onedal_params)
         self._onedal_estimator.fit(X, y, weights, queue=queue)
 
         if self.probability:
@@ -305,67 +159,4 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
 
         self._save_attributes()
 
-    def _onedal_predict(self, X, queue=None):
-        if sklearn_check_version("1.0"):
-            validate_data(
-                self,
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                ensure_2d=False,
-                accept_sparse="csr",
-                reset=False,
-            )
-        else:
-            X = check_array(
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-            )
-
-        return self._onedal_estimator.predict(X, queue=queue)
-
-    def _onedal_predict_proba(self, X, queue=None):
-        if getattr(self, "clf_prob", None) is None:
-            raise NotFittedError(
-                "predict_proba is not available when fitted with probability=False"
-            )
-        from .._config import config_context, get_config
-
-        # We use stock metaestimators below, so the only way
-        # to pass a queue is using config_context.
-        cfg = get_config()
-        cfg["target_offload"] = queue
-        with config_context(**cfg):
-            return self.clf_prob.predict_proba(X)
-
-    def _onedal_decision_function(self, X, queue=None):
-        if sklearn_check_version("1.0"):
-            validate_data(
-                self,
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-                reset=False,
-            )
-        else:
-            X = check_array(
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-            )
-
-        return self._onedal_estimator.decision_function(X, queue=queue)
-
-    def _onedal_score(self, X, y, sample_weight=None, queue=None):
-        return accuracy_score(
-            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
-        )
-
     fit.__doc__ = _sklearn_NuSVC.fit.__doc__
-    predict.__doc__ = _sklearn_NuSVC.predict.__doc__
-    decision_function.__doc__ = _sklearn_NuSVC.decision_function.__doc__
-    score.__doc__ = _sklearn_NuSVC.score.__doc__
diff --git a/sklearnex/svm/nusvr.py b/sklearnex/svm/nusvr.py
index 6c746174ac..557d55bcb5 100644
--- a/sklearnex/svm/nusvr.py
+++ b/sklearnex/svm/nusvr.py
@@ -32,12 +32,13 @@
 if sklearn_check_version("1.6"):
     from sklearn.utils.validation import validate_data
 else:
-    validate_data = BaseSVR._validate_data
+    validate_data = _sklearn_NuSVR._validate_data
 
 
 @control_n_jobs(decorated_methods=["fit", "predict", "score"])
-class NuSVR(_sklearn_NuSVR, BaseSVR):
+class NuSVR(BaseSVR, _sklearn_NuSVR):
     __doc__ = _sklearn_NuSVR.__doc__
+    _onedal_factory = onedal_NuSVR
 
     if sklearn_check_version("1.2"):
         _parameter_constraints: dict = {**_sklearn_NuSVR._parameter_constraints}
@@ -99,34 +100,6 @@ def fit(self, X, y, sample_weight=None):
         )
         return self
 
-    @wrap_output_data
-    def predict(self, X):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "predict",
-            {
-                "onedal": self.__class__._onedal_predict,
-                "sklearn": _sklearn_NuSVR.predict,
-            },
-            X,
-        )
-
-    @wrap_output_data
-    def score(self, X, y, sample_weight=None):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "score",
-            {
-                "onedal": self.__class__._onedal_score,
-                "sklearn": _sklearn_NuSVR.score,
-            },
-            X,
-            y,
-            sample_weight=sample_weight,
-        )
-
     def _onedal_fit(self, X, y, sample_weight=None, queue=None):
         X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight)
         onedal_params = {
@@ -142,29 +115,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
             "max_iter": self.max_iter,
         }
 
-        self._onedal_estimator = onedal_NuSVR(**onedal_params)
+        self._onedal_estimator = self._onedal_factory(**onedal_params)
         self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
         self._save_attributes()
 
-    def _onedal_predict(self, X, queue=None):
-        if sklearn_check_version("1.0"):
-            X = validate_data(
-                self,
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-                reset=False,
-            )
-        else:
-            X = check_array(
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-            )
-        return self._onedal_estimator.predict(X, queue=queue)
-
     fit.__doc__ = _sklearn_NuSVR.fit.__doc__
-    predict.__doc__ = _sklearn_NuSVR.predict.__doc__
-    score.__doc__ = _sklearn_NuSVR.score.__doc__
diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py
index bf5e7f32fc..e87085bd17 100644
--- a/sklearnex/svm/svc.py
+++ b/sklearnex/svm/svc.py
@@ -16,7 +16,6 @@
 
 import numpy as np
 from scipy import sparse as sp
-from sklearn.exceptions import NotFittedError
 from sklearn.metrics import accuracy_score
 from sklearn.svm import SVC as _sklearn_SVC
 from sklearn.utils.validation import (
@@ -41,14 +40,15 @@
 if sklearn_check_version("1.6"):
     from sklearn.utils.validation import validate_data
 else:
-    validate_data = BaseSVC._validate_data
+    validate_data = _sklearn_SVC._validate_data
 
 
 @control_n_jobs(
     decorated_methods=["fit", "predict", "_predict_proba", "decision_function", "score"]
 )
-class SVC(_sklearn_SVC, BaseSVC):
+class SVC(BaseSVC, _sklearn_SVC):
     __doc__ = _sklearn_SVC.__doc__
+    _onedal_factory = onedal_SVC
 
     if sklearn_check_version("1.2"):
         _parameter_constraints: dict = {**_sklearn_SVC._parameter_constraints}
@@ -119,146 +119,6 @@ def fit(self, X, y, sample_weight=None):
 
         return self
 
-    @wrap_output_data
-    def predict(self, X):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "predict",
-            {
-                "onedal": self.__class__._onedal_predict,
-                "sklearn": _sklearn_SVC.predict,
-            },
-            X,
-        )
-
-    @wrap_output_data
-    def score(self, X, y, sample_weight=None):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "score",
-            {
-                "onedal": self.__class__._onedal_score,
-                "sklearn": _sklearn_SVC.score,
-            },
-            X,
-            y,
-            sample_weight=sample_weight,
-        )
-
-    if sklearn_check_version("1.0"):
-
-        @available_if(_sklearn_SVC._check_proba)
-        def predict_proba(self, X):
-            """
-            Compute probabilities of possible outcomes for samples in X.
-
-            The model need to have probability information computed at training
-            time: fit with attribute `probability` set to True.
-
-            Parameters
-            ----------
-            X : array-like of shape (n_samples, n_features)
-                For kernel="precomputed", the expected shape of X is
-                (n_samples_test, n_samples_train).
-
-            Returns
-            -------
-            T : ndarray of shape (n_samples, n_classes)
-                Returns the probability of the sample for each class in
-                the model. The columns correspond to the classes in sorted
-                order, as they appear in the attribute :term:`classes_`.
-
-            Notes
-            -----
-            The probability model is created using cross validation, so
-            the results can be slightly different than those obtained by
-            predict. Also, it will produce meaningless results on very small
-            datasets.
-            """
-            check_is_fitted(self)
-            return self._predict_proba(X)
-
-        @available_if(_sklearn_SVC._check_proba)
-        def predict_log_proba(self, X):
-            """Compute log probabilities of possible outcomes for samples in X.
-
-            The model need to have probability information computed at training
-            time: fit with attribute `probability` set to True.
-
-            Parameters
-            ----------
-            X : array-like of shape (n_samples, n_features) or \
-                    (n_samples_test, n_samples_train)
-                For kernel="precomputed", the expected shape of X is
-                (n_samples_test, n_samples_train).
-
-            Returns
-            -------
-            T : ndarray of shape (n_samples, n_classes)
-                Returns the log-probabilities of the sample for each class in
-                the model. The columns correspond to the classes in sorted
-                order, as they appear in the attribute :term:`classes_`.
-
-            Notes
-            -----
-            The probability model is created using cross validation, so
-            the results can be slightly different than those obtained by
-            predict. Also, it will produce meaningless results on very small
-            datasets.
-            """
-            xp, _ = get_namespace(X)
-
-            return xp.log(self.predict_proba(X))
-
-    else:
-
-        @property
-        def predict_proba(self):
-            self._check_proba()
-            check_is_fitted(self)
-            return self._predict_proba
-
-        def _predict_log_proba(self, X):
-            xp, _ = get_namespace(X)
-            return xp.log(self.predict_proba(X))
-
-        predict_proba.__doc__ = _sklearn_SVC.predict_proba.__doc__
-
-    @wrap_output_data
-    def _predict_proba(self, X):
-        sklearn_pred_proba = (
-            _sklearn_SVC.predict_proba
-            if sklearn_check_version("1.0")
-            else _sklearn_SVC._predict_proba
-        )
-
-        return dispatch(
-            self,
-            "predict_proba",
-            {
-                "onedal": self.__class__._onedal_predict_proba,
-                "sklearn": sklearn_pred_proba,
-            },
-            X,
-        )
-
-    @wrap_output_data
-    def decision_function(self, X):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "decision_function",
-            {
-                "onedal": self.__class__._onedal_decision_function,
-                "sklearn": _sklearn_SVC.decision_function,
-            },
-            X,
-        )
-
-    decision_function.__doc__ = _sklearn_SVC.decision_function.__doc__
-
     def _onedal_gpu_supported(self, method_name, *data):
         class_name = self.__class__.__name__
         patching_status = PatchingConditionsChain(
@@ -322,7 +182,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
             "decision_function_shape": self.decision_function_shape,
         }
 
-        self._onedal_estimator = onedal_SVC(**onedal_params)
+        self._onedal_estimator = self._onedal_factory(**onedal_params)
         self._onedal_estimator.fit(X, y, weights, queue=queue)
 
         if self.probability:
@@ -335,65 +195,4 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
 
         self._save_attributes()
 
-    def _onedal_predict(self, X, queue=None):
-        if sklearn_check_version("1.0"):
-            X = validate_data(
-                self,
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                ensure_2d=False,
-                accept_sparse="csr",
-                reset=False,
-            )
-        else:
-            X = check_array(
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-            )
-        return self._onedal_estimator.predict(X, queue=queue)
-
-    def _onedal_predict_proba(self, X, queue=None):
-        if getattr(self, "clf_prob", None) is None:
-            raise NotFittedError(
-                "predict_proba is not available when fitted with probability=False"
-            )
-        from .._config import config_context, get_config
-
-        # We use stock metaestimators below, so the only way
-        # to pass a queue is using config_context.
-        cfg = get_config()
-        cfg["target_offload"] = queue
-        with config_context(**cfg):
-            return self.clf_prob.predict_proba(X)
-
-    def _onedal_decision_function(self, X, queue=None):
-        if sklearn_check_version("1.0"):
-            X = validate_data(
-                self,
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-                reset=False,
-            )
-        else:
-            X = check_array(
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-            )
-        return self._onedal_estimator.decision_function(X, queue=queue)
-
-    def _onedal_score(self, X, y, sample_weight=None, queue=None):
-        return accuracy_score(
-            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
-        )
-
     fit.__doc__ = _sklearn_SVC.fit.__doc__
-    predict.__doc__ = _sklearn_SVC.predict.__doc__
-    decision_function.__doc__ = _sklearn_SVC.decision_function.__doc__
-    score.__doc__ = _sklearn_SVC.score.__doc__
diff --git a/sklearnex/svm/svr.py b/sklearnex/svm/svr.py
index ff2641bea0..fc99e191a9 100644
--- a/sklearnex/svm/svr.py
+++ b/sklearnex/svm/svr.py
@@ -28,12 +28,13 @@
 if sklearn_check_version("1.6"):
     from sklearn.utils.validation import validate_data
 else:
-    validate_data = BaseSVR._validate_data
+    validate_data = _sklearn_SVR._validate_data
 
 
 @control_n_jobs(decorated_methods=["fit", "predict", "score"])
-class SVR(_sklearn_SVR, BaseSVR):
+class SVR(BaseSVR, _sklearn_SVR):
     __doc__ = _sklearn_SVR.__doc__
+    _onedal_factory = onedal_SVR
 
     if sklearn_check_version("1.2"):
         _parameter_constraints: dict = {**_sklearn_SVR._parameter_constraints}
@@ -96,34 +97,6 @@ def fit(self, X, y, sample_weight=None):
 
         return self
 
-    @wrap_output_data
-    def predict(self, X):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "predict",
-            {
-                "onedal": self.__class__._onedal_predict,
-                "sklearn": _sklearn_SVR.predict,
-            },
-            X,
-        )
-
-    @wrap_output_data
-    def score(self, X, y, sample_weight=None):
-        check_is_fitted(self)
-        return dispatch(
-            self,
-            "score",
-            {
-                "onedal": self.__class__._onedal_score,
-                "sklearn": _sklearn_SVR.score,
-            },
-            X,
-            y,
-            sample_weight=sample_weight,
-        )
-
     def _onedal_fit(self, X, y, sample_weight=None, queue=None):
         X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight)
         onedal_params = {
@@ -139,29 +112,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
             "max_iter": self.max_iter,
         }
 
-        self._onedal_estimator = onedal_SVR(**onedal_params)
+        self._onedal_estimator = self._onedal_factory(**onedal_params)
         self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
         self._save_attributes()
 
-    def _onedal_predict(self, X, queue=None):
-        if sklearn_check_version("1.0"):
-            X = validate_data(
-                self,
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-                reset=False,
-            )
-        else:
-            X = check_array(
-                X,
-                dtype=[np.float64, np.float32],
-                force_all_finite=False,
-                accept_sparse="csr",
-            )
-        return self._onedal_estimator.predict(X, queue=queue)
-
     fit.__doc__ = _sklearn_SVR.fit.__doc__
-    predict.__doc__ = _sklearn_SVR.predict.__doc__
-    score.__doc__ = _sklearn_SVR.score.__doc__
diff --git a/sklearnex/svm/tests/test_svm.py b/sklearnex/svm/tests/test_svm.py
index f0d561744e..985fc1f44c 100755
--- a/sklearnex/svm/tests/test_svm.py
+++ b/sklearnex/svm/tests/test_svm.py
@@ -16,13 +16,20 @@
 
 import numpy as np
 import pytest
+import scipy.sparse as sp
 from numpy.testing import assert_allclose
+from sklearn.datasets import load_diabetes, load_iris, make_classification
 
+from onedal.svm.tests.test_csr_svm import check_svm_model_equal
 from onedal.tests.utils._dataframes_support import (
     _as_numpy,
     _convert_to_dataframe,
     get_dataframes_and_queues,
 )
+from onedal.tests.utils._device_selection import (
+    get_queues,
+    pass_if_not_implemented_for_gpu,
+)
 
 
 @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@@ -91,3 +98,77 @@ def test_sklearnex_import_nusvr(dataframe, queue):
         _as_numpy(svc.dual_coef_), [[-1.0, 0.611111, 1.0, -0.611111]], rtol=1e-3
     )
     assert_allclose(_as_numpy(svc.support_), [1, 2, 3, 5])
+
+
+@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
+@pytest.mark.parametrize(
+    "queue",
+    get_queues("cpu")
+    + [
+        pytest.param(
+            get_queues("gpu"),
+            marks=pytest.mark.xfail(
+                reason="raises UnknownError for linear and rbf, "
+                "Unimplemented error with inconsistent error message "
+                "for poly and sigmoid"
+            ),
+        )
+    ],
+)
+@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
+def test_binary_dataset(queue, kernel):
+    from sklearnex import config_context
+    from sklearnex.svm import SVC
+
+    X, y = make_classification(n_samples=80, n_features=20, n_classes=2, random_state=0)
+    sparse_X = sp.csr_matrix(X)
+
+    dataset = sparse_X, y, sparse_X
+    with config_context(target_offload=queue):
+        clf0 = SVC(kernel=kernel)
+        clf1 = SVC(kernel=kernel)
+        check_svm_model_equal(queue, clf0, clf1, *dataset)
+
+
+@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
+@pytest.mark.parametrize("queue", get_queues())
+@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
+def test_iris(queue, kernel):
+    from sklearnex import config_context
+    from sklearnex.svm import SVC
+
+    if kernel == "rbf":
+        pytest.skip("RBF CSR SVM test failing in 2025.0.")
+    iris = load_iris()
+    rng = np.random.RandomState(0)
+    perm = rng.permutation(iris.target.size)
+    iris.data = iris.data[perm]
+    iris.target = iris.target[perm]
+    sparse_iris_data = sp.csr_matrix(iris.data)
+
+    dataset = sparse_iris_data, iris.target, sparse_iris_data
+
+    with config_context(target_offload=queue):
+        clf0 = SVC(kernel=kernel)
+        clf1 = SVC(kernel=kernel)
+        check_svm_model_equal(queue, clf0, clf1, *dataset, decimal=2)
+
+
+@pass_if_not_implemented_for_gpu(reason="csr svm is not implemented")
+@pytest.mark.parametrize("queue", get_queues())
+@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
+def test_diabetes(queue, kernel):
+    from sklearnex import config_context
+    from sklearnex.svm import SVR
+
+    if kernel == "sigmoid":
+        pytest.skip("Sparse sigmoid kernel function is buggy.")
+    diabetes = load_diabetes()
+
+    sparse_diabetes_data = sp.csr_matrix(diabetes.data)
+    dataset = sparse_diabetes_data, diabetes.target, sparse_diabetes_data
+
+    with config_context(target_offload=queue):
+        clf0 = SVR(kernel=kernel, C=0.1)
+        clf1 = SVR(kernel=kernel, C=0.1)
+        check_svm_model_equal(queue, clf0, clf1, *dataset)