Skip to content

Commit

Permalink
_intelligent_sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-vignal committed Jun 18, 2024
1 parent f5d57e8 commit 058b0ed
Showing 1 changed file with 43 additions and 31 deletions.
74 changes: 43 additions & 31 deletions contribution_plot_improvment.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,51 +59,64 @@ Jittering enhances the clarity of violin plots by dispersing points and making i
- **Clear Class Differentiation**: Facilitates understanding of class-specific contributions.
- **Visual Appeal**: Reduces clutter, enhancing aesthetic appeal of the plot.

## 2. Smart Selection for Diverse Class Representation
## 2. Smart Sampling for Diverse Class Representation

Shapash utilizes a **smart sampling strategy** to ensure a balanced representation of classes within the dataset. This approach involves clustering data points and sampling from each cluster, thereby avoiding biases towards specific classes and ensuring the selected points reflect the overall data distribution.

Here's the function handling smart selection:

```python
def _subset_sampling(
self, selection=None, max_points=2000, col=None, col_value_count=0
):
if col_value_count > 10:
from sklearn.cluster import MiniBatchKMeans

# Clustering data using MiniBatchKMeans
kmeans = MiniBatchKMeans(n_clusters=10, random_state=0)
kmeans.fit(data[[col]] if col else data)
data["group"] = kmeans.predict(data)
def _intelligent_sampling(self, data, max_points, col_value_count, random_seed):
"""
Performs intelligent sampling based on the distribution of values in the specified column.
"""
rng = np.random.default_rng(seed=random_seed)
is_col_str = True

# Check if data is numerical data
if data.dtype.kind in "fc":
is_col_str = False

if (col_value_count < len(data) / 20) or is_col_str:
cluster_labels = data
cluster_counts = cluster_labels.value_counts()
else:
# Grouping data based on index or column value
data["group"] = (
data.index % 10 if col is None else data[col].apply(lambda x: int(x % 10))
)

idx_list = []
for group in data["group"].unique():
data_group = data[data["group"] == group]
sample_size = min(len(data_group), max_points // 10)
idx_list += data_group.sample(n=sample_size, random_state=0).index.to_list()
return idx_list
n_clusters = min(100, len(data) // 20)
kmeans = KMeans(n_clusters=n_clusters, random_state=random_seed, n_init="auto")
cluster_labels = pd.Series(kmeans.fit_predict(data.values.reshape(-1, 1)))
cluster_counts = cluster_labels.value_counts()

weights = cluster_counts.apply(lambda x: (x ** 0.5) / x).to_dict()
selection_weights = cluster_labels.apply(lambda x: weights[x])
selection_weights /= selection_weights.sum()
selected_indices = rng.choice(
data.index.tolist(), max_points, p=selection_weights, replace=False
)
return selected_indices
```

### How It Works

The smart selection process begins by evaluating the **number of unique values** in a specified column (`col_value_count`). If this number is greater than 10, the data is clustered using the **MiniBatchKMeans** algorithm from the `sklearn` library. The algorithm creates 10 clusters, and each data point is assigned to one of these clusters.
The `_intelligent_sampling` function selects a subset of data based on the distribution of values in a specified column. Here’s how it operates:

If the number of unique values is 10 or fewer, a simpler approach is used: data points are grouped based on their index or a specific column value.
1. **Data Type Handling**:
- It checks if the column (`col`) contains numerical (`float` or `int`) or categorical (`object` or `category`) data.

1. **Clustering with MiniBatchKMeans**:
- If there are more than 10 unique values, `MiniBatchKMeans` clusters the data into 10 groups.
- Each data point is assigned a cluster label stored in the "group" column.
2. **Condition Check**:
- If the number of unique values (`col_value_count`) is less than 5% of the total rows in the dataset (`len(data) / 20`) or if the column contains string data, it uses the original column values without clustering (`is_col_str`).

2. **Grouping without Clustering**:
- If there are 10 or fewer unique values, data points are assigned to groups based on their index or a specific column value.
3. **Clustering Approach**:
- **Numeric Data**: For numeric columns with more than 5% unique values, the function performs KMeans clustering. It determines the number of clusters (`n_clusters`) based on either 100 clusters or a fraction of the dataset size (`len(data[col]) // 20`).
- **Categorical Data**: No clustering is applied to categorical data; it directly uses the original values.

After grouping, the function samples points from each group to ensure that the final selection is diverse and representative of the entire dataset.
4. **Cluster Weight Calculation**:
- If clustering is applied, weights for each cluster are calculated based on the square root of cluster counts (`(x ** 0.5) / x`). This ensures a balanced representation of clusters in the sampling process.

5. **Selection Process**:
- The function normalizes the calculated weights (`selection_weights`) so that they sum to 1, ensuring proportional selection probabilities.

6. **Random Selection**:
- Using a random number generator (`rng.choice`), the function selects `max_points` indices from the dataset based on the normalized weights (`selection_weights`). This strategy ensures the selected subset reflects the original data’s distribution.

### Summary

Expand Down Expand Up @@ -157,7 +170,6 @@ density_plot = go.Scatter(
showlegend=False,
line={"color": self._style_dict["contrib_distribution"]},
)
fig.add_trace(density_plot)
```

### How It Works
Expand Down

0 comments on commit 058b0ed

Please sign in to comment.