From 571239522a386b7ac9be914ed7fd8459fceade38 Mon Sep 17 00:00:00 2001
From: Sarah Yurick <sarahyurick@gmail.com>
Date: Mon, 23 Dec 2024 13:19:06 -0800
Subject: [PATCH 1/2] first commit

Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
---
 nemo_curator/classifiers/base.py | 32 +++++++++++---------------------
 1 file changed, 11 insertions(+), 21 deletions(-)

diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py
index 699d034de..0c4bf70e0 100644
--- a/nemo_curator/classifiers/base.py
+++ b/nemo_curator/classifiers/base.py
@@ -121,11 +121,14 @@ def _run_classifier_helper(
     prob_col: str = None,
 ) -> "dask_cudf.DataFrame":
 
-    keep_prob = prob_col is not None
-    prob_internal_col = "_prob"
-    # TODO: Make crossfit handle this cleanly
-    pred_internal_col = "labels"
+    if prob_col:
+        keep_prob = True
+    else:
+        keep_prob = False
+        prob_col = "_prob"
+
     df["sliced_text"] = df[text_field].str.slice(0, max_chars)
+
     columns_to_keep_list = df.columns.to_list()
     columns_to_keep_list.remove("sliced_text")
 
@@ -135,29 +138,16 @@ def _run_classifier_helper(
             model,
             sorted_data_loader=True,
             batch_size=batch_size,
-            pred_output_col=prob_internal_col,
+            model_output_cols=[label_col],
+            pred_output_col=prob_col,
         ),
         repartition=df.npartitions,
         keep_cols=columns_to_keep_list,
     )
     df = classifier_pipe(df)
 
-    # TODO: Make crossfit handle this cleanly
-    # to prevent the labeler from dropping the prob_internal_col
-    # and combine it into a single step
-    labeling_pipe = op.Sequential(
-        op.Labeler(labels, cols=[prob_internal_col]),
-        keep_cols=columns_to_keep_list + [prob_internal_col],
-    )
-    df = labeling_pipe(df)
-
-    if keep_prob:
-        df = df.rename(
-            columns={prob_internal_col: prob_col, pred_internal_col: label_col},
-        )
-    else:
-        df = df.rename(columns={pred_internal_col: label_col})
-        df = df.drop(columns=[prob_internal_col])
+    if not keep_prob:
+        df = df.drop(columns=[prob_col])
 
     return df
 

From 498910d5cb8175dfd8fcdd0a9494ee7ba6d18dbf Mon Sep 17 00:00:00 2001
From: Sarah Yurick <sarahyurick@gmail.com>
Date: Tue, 24 Dec 2024 10:29:30 -0800
Subject: [PATCH 2/2] working code and pytest

Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
---
 nemo_curator/classifiers/base.py | 15 ++++--------
 tests/test_classifiers.py        | 41 +++++++++++++++++++-------------
 2 files changed, 30 insertions(+), 26 deletions(-)

diff --git a/nemo_curator/classifiers/base.py b/nemo_curator/classifiers/base.py
index 0c4bf70e0..4f8cdc253 100644
--- a/nemo_curator/classifiers/base.py
+++ b/nemo_curator/classifiers/base.py
@@ -122,33 +122,28 @@ def _run_classifier_helper(
 ) -> "dask_cudf.DataFrame":
 
     if prob_col:
-        keep_prob = True
+        df[prob_col] = 0
     else:
-        keep_prob = False
         prob_col = "_prob"
 
-    df["sliced_text"] = df[text_field].str.slice(0, max_chars)
-
     columns_to_keep_list = df.columns.to_list()
-    columns_to_keep_list.remove("sliced_text")
 
     classifier_pipe = op.Sequential(
-        op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="default"),
+        op.Tokenizer(
+            model, cols=[text_field], tokenizer_type="default", max_chars=max_chars
+        ),
         op.Predictor(
             model,
             sorted_data_loader=True,
             batch_size=batch_size,
-            model_output_cols=[label_col],
             pred_output_col=prob_col,
         ),
+        op.Labeler(labels, cols=[prob_col], suffix=label_col),
         repartition=df.npartitions,
         keep_cols=columns_to_keep_list,
     )
     df = classifier_pipe(df)
 
-    if not keep_prob:
-        df = df.drop(columns=[prob_col])
-
     return df
 
 
diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py
index da427689b..5d681089f 100644
--- a/tests/test_classifiers.py
+++ b/tests/test_classifiers.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-
 import pytest
 from distributed import Client
 
@@ -48,24 +46,35 @@ def domain_dataset():
 
 
 @pytest.mark.gpu
-def test_domain_classifier(gpu_client, domain_dataset):
+@pytest.mark.parametrize("keep_prob", [True, False])
+def test_domain_classifier(gpu_client, domain_dataset, keep_prob):
     from nemo_curator.classifiers import DomainClassifier
 
-    classifier = DomainClassifier()
-    result_dataset = classifier(dataset=domain_dataset)
-    result_pred = result_dataset.df.compute()["domain_pred"]
+    if keep_prob:
+        prob_column = "domain_prob"
+    else:
+        prob_column = None
 
-    expected_pred = cudf.Series(
-        [
-            "Computers_and_Electronics",
-            "Finance",
-            "Health",
-            "Jobs_and_Education",
-            "Travel_and_Transportation",
-        ]
-    )
+    classifier = DomainClassifier(prob_column=prob_column)
+    result_dataset = classifier(dataset=domain_dataset)
 
-    assert result_pred.equals(expected_pred)
+    if keep_prob:
+        result_df = result_dataset.df.compute()
+        assert "domain_prob" in result_df.columns
+    else:
+        result_pred = result_dataset.df.compute()["domain_pred"]
+
+        expected_pred = cudf.Series(
+            [
+                "Computers_and_Electronics",
+                "Finance",
+                "Health",
+                "Jobs_and_Education",
+                "Travel_and_Transportation",
+            ]
+        )
+
+        assert result_pred.equals(expected_pred)
 
 
 @pytest.mark.gpu