diff --git a/.gitignore b/.gitignore index cb3c3c8..f5130f0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ *.pyc .project .pydevproject -__pycache__/ \ No newline at end of file +__pycache__/ +*.egg +*.egg-info/ +build/ \ No newline at end of file diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/highdimensional_boundary/__init__.py b/highdimensional_boundary/__init__.py new file mode 100644 index 0000000..3d4e40c --- /dev/null +++ b/highdimensional_boundary/__init__.py @@ -0,0 +1,2 @@ +from .decisionboundaryplot import DBPlot +__all__ = ['DBPlot',] \ No newline at end of file diff --git a/decisionboundaryplot.py b/highdimensional_boundary/decisionboundaryplot.py similarity index 98% rename from decisionboundaryplot.py rename to highdimensional_boundary/decisionboundaryplot.py index f9291ca..f259ca6 100644 --- a/decisionboundaryplot.py +++ b/highdimensional_boundary/decisionboundaryplot.py @@ -9,7 +9,7 @@ import nlopt import random from scipy.spatial.distance import euclidean, squareform, pdist -from utils import minimum_spanning_tree, polar_to_cartesian +from .utils import minimum_spanning_tree, polar_to_cartesian from sklearn.model_selection import GridSearchCV from sklearn.svm import SVC from sklearn.metrics import accuracy_score, f1_score @@ -270,6 +270,7 @@ def fit(self, X, y, training_indices=None): print( "Failed to find initial decision boundary. Retrying... If this keeps happening, increasing the acceptance threshold might help. Also, make sure the classifier is able to find a point with 0.5 prediction probability (usually requires an even number of estimators/neighbors/etc)." ) + self.acceptance_threshold += 0.03 return self.fit(X, y, training_indices) # step 3. look for decision boundary points between already known db @@ -376,6 +377,7 @@ def plot( background_resolution=100, scatter_size_scale=1.0, legend=True, + annotate=None, ): """Plots the dataset and the identified decision boundary in 2D. @@ -408,6 +410,8 @@ def plot( legend : boolean, optional (default=False) Whether to display a legend + annotate : float, optional (default=None) + If not None, it specifies the percentage of points to annotate as a float: range from 0 to 1.0 Returns ------- plt : The matplotlib.pyplot or axis object which has been passed in, after @@ -493,14 +497,10 @@ def plot( ) # label data points with their indices - for i in range(len(self.X2d)): - plt.text( - self.X2d[i, 0] + (self.X2d_xmax - self.X2d_xmin) * 0.5e-2, - self.X2d[i, 1] + (self.X2d_ymax - self.X2d_ymin) * 0.5e-2, - str(i), - size=8, - ) - + if annotate: + step = int(len(self.X2d)*annotate) + 1 + for i, txt in enumerate(range(self.X2d.shape[0])[:step:]): + plt.text(self.X2d[i, 0], self.X2d[i, 1], txt, fontsize='xx-small') if legend: plt.legend( [ diff --git a/demo.py b/highdimensional_boundary/demo.py similarity index 100% rename from demo.py rename to highdimensional_boundary/demo.py diff --git a/uci_loader.py b/highdimensional_boundary/uci_loader.py similarity index 100% rename from uci_loader.py rename to highdimensional_boundary/uci_loader.py diff --git a/utils.py b/highdimensional_boundary/utils.py similarity index 100% rename from utils.py rename to highdimensional_boundary/utils.py diff --git a/requirements.txt b/requirements.txt index b4e720d..6ed70e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -matplotlib==3.0.2 -scipy==1.1.0 -numpy==1.15.4 +matplotlib +scipy +numpy nlopt==2.6.1 -scikit_learn==0.21.2 +scikit_learn diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..b80f002 --- /dev/null +++ b/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup, find_packages +from os.path import dirname, join, realpath +from textwrap import dedent + +PROJECT_ROOT = dirname(realpath(__file__)) +REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt") + +with open(REQUIREMENTS_FILE) as f: + install_reqs = f.read().splitlines() + +install_reqs.append("setuptools") + +setup( + name='highdimensional_boundary', + version='1.0.0', + author='Tamas Madl', + packages=find_packages(), + install_requires=install_reqs, +)