[Bug]: The behavior of kmeans in SPU does not match kmeans in sklearn #536

Closed
winnylyc opened this issue Feb 2, 2024 · 15 comments

@winnylyc
Contributor

winnylyc commented Feb 2, 2024

Issue Type

Currentness/Accuracy

Modules Involved

Others

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.7.0b0

OS Platform and Distribution

Linux Ubuntu 22.04

Python Version

3.10

Compiler Version

No response

Current Behavior?

Hello, sorry for disturbing you.

When I use the KMEANS implemented in sml/cluster, its behavior does not match the KMeans implemented in sklearn for some inputs. The code is included below.

Standalone code to reproduce the issue

import jax.numpy as jnp
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim
from sml.cluster.kmeans import KMEANS

# 3-party ABY3 simulator over a 64-bit field
sim = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)
def proc(x):
    model = KMEANS(n_clusters=3, n_samples=x.shape[0], max_iter=10)
    model.fit(x)
    return model._centers
X = jnp.array([[-4, -3, -2, -1]]).T
result = spsim.sim_jax(sim, proc)(X)
print("result\n", result)

# Compare with sklearn
from sklearn.cluster import KMeans

model = KMeans(n_clusters=3)
model.fit(X)
print("sklearn:\n", model.cluster_centers_)

Relevant log output

result
 [[-1.4999962]
 [ 0.       ]
 [-3.4999924]]
sklearn:
 [[-1. ]
 [-3.5]
 [-2. ]]
@deadlywing
Contributor

Hello, thanks for reporting this.

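The reason behind this is that sklearn runs the clustering algorithm several times (see its n_init parameter), each with different initial centers, and keeps the best run, i.e. the one with the lowest inertia (the sum of squared distances from each sample to its nearest center).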

However, SML runs the procedure only once for efficiency, so it may produce different outputs.
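A quick way to check this explanation is to compare the inertia of the two results reported in the issue. The sketch below is plain JAX outside SPU; `inertia` is a hypothetical helper, and the SPU output is taken as rounded to -1.5, 0 and -3.5. With those centers no sample is nearest to the center at 0, which is consistent with SML having converged to a different, worse local optimum.

```python
import jax.numpy as jnp


def inertia(x, centers):
    """Sum of squared distances from each sample to its nearest center.

    x: [n_samples, n_features], centers: [n_clusters, n_features].
    """
    d2 = jnp.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)  # [n_samples, n_clusters]
    return jnp.sum(jnp.min(d2, axis=1))


X = jnp.array([[-4.0, -3.0, -2.0, -1.0]]).T
spu_centers = jnp.array([[-1.5], [0.0], [-3.5]])  # SPU output above, rounded
sklearn_centers = jnp.array([[-1.0], [-3.5], [-2.0]])  # sklearn output above

print(float(inertia(X, spu_centers)))  # 1.0
print(float(inertia(X, sklearn_centers)))  # 0.5
```

The sklearn centers give half the inertia, so a best-of-several-runs strategy would prefer them whenever one of its runs reaches them.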

BTW, it's pretty straightforward to support this functionality. Would you mind doing this job?
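The multi-run strategy can be sketched as follows, assuming the initial centers for every run are already available as an array of shape [n_init, n_clusters, n_features]. `lloyd` and `kmeans_n_init` are hypothetical names, not SML's API, and the code runs in plain JAX; under SPU the `argmin` result would be secret, so the final indexing would become a secret-index lookup.

```python
import jax
import jax.numpy as jnp


def lloyd(x, centers, max_iter=10):
    """Plain Lloyd iterations from one set of initial centers (hypothetical helper)."""

    def step(_, c):
        # Squared distances, shape [n_samples, n_clusters].
        d2 = jnp.sum((x[:, None, :] - c[None, :, :]) ** 2, axis=-1)
        # One-hot cluster assignment, shape [n_samples, n_clusters].
        onehot = jax.nn.one_hot(jnp.argmin(d2, axis=1), c.shape[0], dtype=x.dtype)
        counts = jnp.sum(onehot, axis=0)  # samples per cluster
        sums = onehot.T @ x  # per-cluster sum of samples, [n_clusters, n_features]
        # Move each center to the mean of its samples; keep it if its cluster is empty.
        return jnp.where(counts[:, None] > 0, sums / jnp.maximum(counts, 1.0)[:, None], c)

    return jax.lax.fori_loop(0, max_iter, step, centers)


def inertia(x, centers):
    """Sum of squared distances from each sample to its nearest center."""
    d2 = jnp.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.min(d2, axis=1))


def kmeans_n_init(x, init_centers, max_iter=10):
    """Run Lloyd once per initialization and keep the run with the lowest inertia.

    init_centers: [n_init, n_clusters, n_features].
    """
    runs = jax.vmap(lambda c: lloyd(x, c, max_iter))(init_centers)
    scores = jax.vmap(lambda c: inertia(x, c))(runs)
    return runs[jnp.argmin(scores)]


X = jnp.array([[-4.0, -3.0, -2.0, -1.0]]).T
inits = jnp.array([
    [[-1.0], [0.0], [-4.0]],   # converges to -1.5, 0, -3.5 (like the SPU result)
    [[-1.0], [-2.0], [-3.5]],  # converges to -1, -2, -3.5 (sklearn's centers)
])
best = kmeans_n_init(X, inits)
print(best.ravel())  # expected centers: -1, -2, -3.5
```

The two hand-picked initializations reproduce both outcomes from the issue, and the selection keeps the second one because its inertia (0.5) is lower than the first (1.0).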

@winnylyc
Contributor Author

winnylyc commented Feb 3, 2024

Thanks for your response!
I will try solving this problem.

@winnylyc
Contributor Author

winnylyc commented Feb 4, 2024

I observed that another factor behind this difference is that the default method for generating the initial centers is different. I will also try adding a new initialization method.

@deadlywing
Contributor

I observed that another factor behind this difference is that the default method for generating the initial centers is different. I will also try adding a new initialization method.

Exactly. sklearn uses k-means++ to choose the initial centroids. If you want to add this method, you MUST generate the required random values before running in SPU.
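A minimal sketch of k-means++ seeding that consumes pre-generated uniform values instead of calling a random generator, assuming those uniforms can be public. `kmeans_plus_plus_init` is a hypothetical name. It uses inverse-CDF sampling: each new center is the first sample whose cumulative share of the squared distances exceeds the next uniform value. sklearn's own implementation (`sklearn.cluster.kmeans_plusplus`) is a greedy variant that evaluates several candidates per step (`n_local_trials`), so this sketch will not reproduce sklearn's picks exactly.

```python
import jax.numpy as jnp


def kmeans_plus_plus_init(x, uniforms):
    """k-means++ seeding driven by pre-generated uniform values in [0, 1).

    x: [n_samples, n_features]; uniforms: [n_clusters]. Returns [n_clusters, n_features].
    Assumes x has at least n_clusters distinct rows (otherwise the total distance can be 0).
    """
    n_samples = x.shape[0]
    n_clusters = uniforms.shape[0]
    # First center: a uniformly random sample.
    first = jnp.minimum(jnp.floor(uniforms[0] * n_samples).astype(jnp.int32), n_samples - 1)
    centers = [x[first]]
    # Squared distance from each sample to its nearest chosen center.
    d2 = jnp.sum((x - x[first]) ** 2, axis=1)
    for k in range(1, n_clusters):
        # Sample the next center with probability proportional to d2:
        # the first index whose cumulative probability exceeds uniforms[k].
        cdf = jnp.cumsum(d2) / jnp.sum(d2)
        idx = jnp.minimum(jnp.searchsorted(cdf, uniforms[k], side="right"), n_samples - 1)
        centers.append(x[idx])
        d2 = jnp.minimum(d2, jnp.sum((x - x[idx]) ** 2, axis=1))
    return jnp.stack(centers)


X = jnp.array([[-4.0, -3.0, -2.0, -1.0]]).T
# Assumed to be generated ahead of time, e.g. with jax.random.uniform from a fixed seed.
u = jnp.array([0.1, 0.9, 0.75])
print(kmeans_plus_plus_init(X, u).ravel())  # expected centers: -4, -1, -2
```

Only the uniforms have to be generated ahead of time. The distances depend on the data, so under SPU the `searchsorted` and `x[idx]` steps would operate on secret values; this sketch does not establish how efficient that would be.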

@winnylyc
Contributor Author

winnylyc commented Feb 4, 2024

Sorry, I don't quite follow. Do you mean I cannot use functions from jax.random in my implementation of kmeans++?

@deadlywing
Contributor

Sorry, I don't quite follow. Do you mean I cannot use functions from jax.random in my implementation of kmeans++?

You can't generate random values in the SPU runtime; see #80.
But you can use jax.random.xxx in __init__, for example:

self.init_params = jax.random.randint(
    jax.random.PRNGKey(1), shape=[self.n_clusters], minval=0, maxval=n_samples
)

@winnylyc
Contributor Author

winnylyc commented Feb 4, 2024

I think I get your point. You mean that I need to generate all the random values I need in the __init__ function, and that using jax.random.xxx in any other function will cause unexpected behavior?
So does the __init__ function run in the SPU runtime?

@deadlywing
Contributor

I think I get your point. You mean that I need to generate all the random values I need in the __init__ function, and that using jax.random.xxx in any other function will cause unexpected behavior?

There are two cases in SML:

  1. Generating them in __init__ is OK only if the random values can be public (all parties can see them in plaintext).
  2. If the random values are sensitive, you can pass them to fit as input parameters, so other parties can't see them.
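The two cases can be illustrated with a hypothetical skeleton (`KMeansRandomnessSketch` is not SML's KMEANS). Case 1 follows the `jax.random.randint` snippet above: the indices depend only on a constant seed, so every party can compute and see them. Case 2 takes the random values as an argument of `fit`, so they can be supplied as an input, for example as an extra argument of the function passed to `spsim.sim_jax`, like `X` in the reproduction code. Note that `randint` can repeat an index, which would give duplicate initial centers.

```python
import jax
import jax.numpy as jnp


class KMeansRandomnessSketch:
    """Hypothetical skeleton (not SML's KMEANS) showing where randomness can come from."""

    def __init__(self, n_clusters, n_samples, n_init=1, seed=1):
        self.n_clusters = n_clusters
        # Case 1: public randomness. Derived only from a constant seed,
        # so every party can see these sample indices.
        self.init_params = jax.random.randint(
            jax.random.PRNGKey(seed), shape=[n_init, n_clusters], minval=0, maxval=n_samples
        )

    def fit(self, x, init_indices=None):
        # Case 2: sensitive randomness is an argument of fit(), so it can be provided
        # as a (secret) input instead of being fixed by a public seed.
        indices = self.init_params if init_indices is None else init_indices
        # Initial centers, shape [n_init, n_clusters, n_features]
        # (the Lloyd iterations are omitted from this sketch).
        self.init_centers_ = jnp.take(x, indices, axis=0)
        return self


X = jnp.array([[-4.0, -3.0, -2.0, -1.0]]).T
public = KMeansRandomnessSketch(n_clusters=3, n_samples=4, n_init=2).fit(X)
private = KMeansRandomnessSketch(n_clusters=3, n_samples=4).fit(X, jnp.array([[0, 1, 2]]))
print(public.init_centers_.shape)  # (2, 3, 1)
```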

@winnylyc
Contributor Author

winnylyc commented Feb 4, 2024

So the values of all the attributes set in the __init__ function are public, right?
In addition, I think all the random values in kmeans++ can be public, so I will generate them in the __init__ function. Is that appropriate?

@deadlywing
Contributor

So the values of all the attributes set in the __init__ function are public, right?

Yes.

In addition, I think all the random values in kmeans++ can be public, so I will generate them in the __init__ function. Is that appropriate?

I think so.

@winnylyc
Contributor Author

winnylyc commented Feb 4, 2024

Thanks for your reminder! It really helps a lot.

@winnylyc
Contributor Author

winnylyc commented Feb 4, 2024

When I implemented the selection of the best centers, I encountered behavior that looks like a bug.
Here is the code.

import jax.numpy as jnp
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

sim = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)
def proc(x, centers):
    C = x.reshape((1, 1, x.shape[0], x.shape[1])) - centers.reshape(
        (centers.shape[0], centers.shape[1], 1, centers.shape[2])
    )  # shape [n_init, n_clusters, n_samples, n_features]
    distance = jnp.sum(jnp.square(C), axis=3)  # shape [n_init, n_clusters, n_samples]
    # Distance from each sample to its nearest center, summed per init: shape [n_init].
    # (jnp.take(distance, index) without an axis would index the flattened array.)
    inertia = jnp.sum(jnp.min(distance, axis=1), axis=1)
    # Selection: pick centers of shape [n_clusters, n_features] from [n_init, n_clusters, n_features]
    centers_best = centers[jnp.argmin(inertia)]
    return centers_best

X = jnp.ones((4, 2))  # shape [n_samples, n_features]
centers = jnp.ones((2, 3, 2))  # shape [n_init, n_clusters, n_features]

result = spsim.sim_jax(sim, proc)(X, centers)
print("result\n", result)

Here is an example with n_init = 2; it raises the following runtime error.

Traceback (most recent call last):
  File "smalltest.py", line 23, in <module>
    result = spsim.sim_jax(sim, proc)(X, centers)
  File "/home/ylipf/anaconda3/envs/spu/lib/python3.8/site-packages/spu/utils/simulation.py", line 169, in wrapper
    out_flat = sim(executable, *args_flat)
  File "/home/ylipf/anaconda3/envs/spu/lib/python3.8/site-packages/spu/utils/simulation.py", line 117, in __call__
    parties = [job.join() for job in jobs]
  File "/home/ylipf/anaconda3/envs/spu/lib/python3.8/site-packages/spu/utils/simulation.py", line 117, in <listcomp>
    parties = [job.join() for job in jobs]
  File "/home/ylipf/anaconda3/envs/spu/lib/python3.8/site-packages/spu/utils/simulation.py", line 47, in join
    raise self.exc
  File "/home/ylipf/anaconda3/envs/spu/lib/python3.8/site-packages/spu/utils/simulation.py", line 40, in run
    self.ret = self._target(*self._args, **self._kwargs)
  File "/home/ylipf/anaconda3/envs/spu/lib/python3.8/site-packages/spu/utils/simulation.py", line 106, in wrapper
    rt.run(executable)
  File "/home/ylipf/anaconda3/envs/spu/lib/python3.8/site-packages/spu/api.py", line 44, in run
    return self._vm.Run(executable.SerializeToString())
RuntimeError: what: 
        [Enforce fail at libspu/core/ndarray_ref.cc:300] (end_indices[idx] <= shape()[idx]). Slice end at axis 0 = 3 is larger than input shape 2
Stacktrace:
#0 spu::kernel::hal::slice()+0x7fff4b00d93e
#1 spu::kernel::hlo::SecretDynamicSlice()+0x7fff4af236f2
#2 spu::kernel::hlo::DynamicSlice()+0x7fff4af24540
#3 spu::device::pphlo::dispatchOp<>()+0x7fff4a91d3fa
#4 spu::device::pphlo::dispatchOp<>()+0x7fff4a91f0e4
#5 spu::device::pphlo::dispatchOp<>()+0x7fff4a920f93
#6 spu::device::pphlo::dispatchOp<>()+0x7fff4a922fe9
#7 spu::device::pphlo::dispatchOp<>()+0x7fff4a9252f0
#8 spu::device::runBlock()+0x7fff4aa60f05
#9 spu::device::runRegion()+0x7fff4aa62f93
#10 spu::device::executeImpl()+0x7fff4a4497f0
#11 spu::RuntimeWrapper::Run()+0x7fff49c2e10c
#12 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()+0x7fff49c025b5
#13 pybind11::cpp_function::dispatcher()+0x7fff49bf785e
#14 cfunction_call_varargs+0x55555569000e

Strangely, the runtime error occurs only when the first dimension of centers (n_init) is 2; every other value of n_init executes as expected. Is this a bug or expected behavior?

@tpppppub
Collaborator

tpppppub commented Feb 4, 2024

Please try the latest main branch code. This bug was fixed by PR #532.
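With #532, plain indexing such as `centers[jnp.argmin(inertia)]` works. As an alternative sketch that does not slice at a secret offset (the failing frame in the trace above is `SecretDynamicSlice`), the selection can be written as a one-hot weighted sum over the n_init axis. The helper names are hypothetical, inertia is computed with `jnp.min` over the cluster axis, and the sketch makes no claim that this is faster in SPU.

```python
import jax
import jax.numpy as jnp


def inertia_per_init(x, centers):
    """x: [n_samples, n_features]; centers: [n_init, n_clusters, n_features] -> [n_init]."""
    diff = x[None, None, :, :] - centers[:, :, None, :]  # [n_init, n_clusters, n_samples, n_features]
    distance = jnp.sum(diff ** 2, axis=3)  # [n_init, n_clusters, n_samples]
    return jnp.sum(jnp.min(distance, axis=1), axis=1)  # nearest-center distances, summed per init


def select_best_centers(centers, inertia):
    """Same result as centers[jnp.argmin(inertia)], written as a one-hot weighted sum."""
    mask = jax.nn.one_hot(jnp.argmin(inertia), centers.shape[0], dtype=centers.dtype)  # [n_init]
    # Contract the n_init axis: only the selected initialization has weight 1.
    return jnp.tensordot(mask, centers, axes=1)  # [n_clusters, n_features]


X = jnp.array([[-4.0, -3.0, -2.0, -1.0]]).T
centers = jnp.array([
    [[-1.5], [0.0], [-3.5]],   # inertia 1.0
    [[-1.0], [-2.0], [-3.5]],  # inertia 0.5
])
best = select_best_centers(centers, inertia_per_init(X, centers))
print(best.ravel())  # expected centers: -1, -2, -3.5
```

The trade-off is a multiply-and-sum over all n_init candidates instead of a single slice, which keeps the data access pattern independent of which initialization wins.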

@winnylyc
Contributor Author

winnylyc commented Feb 4, 2024

Sorry for not testing with the latest commit and for disturbing you. The code executes as expected on the latest commit.

@winnylyc
Contributor Author

Solved with #546.

3 participants