Skip to content

Commit

Permalink
convert case_sensitive flag to convert_lowercase
Browse files (browse the repository at this point in the history)
  • Loading branch information
suzannejin committed Jan 17, 2025
1 parent a44c924 commit 6ee20d9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 28 deletions.
13 changes: 7 additions & 6 deletions src/stimulus/data/encoding/encoders.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -89,7 +89,8 @@ class TextOneHotEncoder(AbstractEncoder):
Attributes:
alphabet (str): the alphabet to one hot encode the data with.
case_sensitive (bool): whether the encoder is case sensitive or not. Default = False
convert_lowercase (bool): whether the encoder converts the sequence (and the alphabet) to lowercase
before encoding. Default = False
padding (bool): whether to pad the sequences with zero or not. Default = False
encoder (OneHotEncoder): preprocessing.OneHotEncoder object initialized with self.alphabet
Expand All @@ -101,7 +102,7 @@ class TextOneHotEncoder(AbstractEncoder):
_sequence_to_array: transforms a sequence into a numpy array
"""

def __init__(self, alphabet: str = "acgt", case_sensitive: bool = False, padding: bool = False) -> None:
def __init__(self, alphabet: str = "acgt", convert_lowercase: bool = False, padding: bool = False) -> None:
"""Initialize the TextOneHotEncoder class.
Args:
Expand All @@ -115,11 +116,11 @@ def __init__(self, alphabet: str = "acgt", case_sensitive: bool = False, padding
logger.error(error_msg)
raise ValueError(error_msg)

if not case_sensitive:
if convert_lowercase:
alphabet = alphabet.lower()

self.alphabet = alphabet
self.case_sensitive = case_sensitive
self.convert_lowercase = convert_lowercase
self.padding = padding

self.encoder = preprocessing.OneHotEncoder(
Expand Down Expand Up @@ -150,7 +151,7 @@ def _sequence_to_array(self, sequence: str) -> np.array:
logger.error(error_msg)
raise ValueError(error_msg)

if not self.case_sensitive:
if self.convert_lowercase:
sequence = sequence.lower()

sequence_array = np.array(list(sequence))
Expand Down Expand Up @@ -186,7 +187,7 @@ def encode(self, data: str) -> torch.Tensor:
[0, 0, 0, 1],
[0, 0, 0, 0]])
>>> encoder = TextOneHotEncoder(alphabet="ACgt", case_sensitive=True)
>>> encoder = TextOneHotEncoder(alphabet="ACgt")
>>> encoder.encode("acgt")
tensor([[0, 0, 0, 0],
[0, 0, 0, 0],
Expand Down
42 changes: 20 additions & 22 deletions tests/data/encoding/test_encoders.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,14 +17,14 @@ class TestTextOneHotEncoder:
@staticmethod
@pytest.fixture
def encoder_default():
"""Provides a default encoder with lowercase alphabet 'acgt' (not case-sensitive)."""
return TextOneHotEncoder(alphabet="acgt", case_sensitive=False, padding=True)
"""Provides a default encoder."""
return TextOneHotEncoder(alphabet="acgt", padding=True)

@staticmethod
@pytest.fixture
def encoder_case_sensitive():
"""Provides an encoder with mixed case alphabet 'ACgt' (case-sensitive)."""
return TextOneHotEncoder(alphabet="ACgt", case_sensitive=True, padding=True)
def encoder_lowercase():
"""Provides an encoder with convert_lowercase set to True."""
return TextOneHotEncoder(alphabet="ACgt", convert_lowercase=True, padding=True)

# ---- Test for initialization ---- #

Expand All @@ -36,7 +36,7 @@ def test_init_with_non_string_alphabet_raises_value_error(self):
def test_init_with_string_alphabet(self):
encoder = TextOneHotEncoder(alphabet="acgt")
assert encoder.alphabet == "acgt"
assert encoder.case_sensitive is False
assert encoder.convert_lowercase is False
assert encoder.padding is False

# ---- Tests for _sequence_to_array ---- #
Expand All @@ -54,17 +54,15 @@ def test_sequence_to_array_returns_correct_shape(self, encoder_default):
# check content
assert (arr.flatten() == list(seq)).all()

def test_sequence_to_array_is_case_insensitive(self, encoder_default):
def test_sequence_to_array_is_case_sensitive(self, encoder_default):
seq = "AcGT"
arr = encoder_default._sequence_to_array(seq)
# Since encoder_default is not case sensitive, sequence is lowercased internally.
assert (arr.flatten() == list("acgt")).all()
assert (arr.flatten() == list("AcGT")).all()

def test_sequence_to_array_is_case_sensitive(self, encoder_case_sensitive):
def test_sequence_to_array_is_lowercase(self, encoder_lowercase):
seq = "AcGT"
arr = encoder_case_sensitive._sequence_to_array(seq)
# With case_sensitive=True, we do not modify 'AcGT'
assert (arr.flatten() == list("AcGT")).all()
arr = encoder_lowercase._sequence_to_array(seq)
assert (arr.flatten() == list("acgt")).all()

# ---- Tests for encode ---- #

Expand All @@ -81,25 +79,25 @@ def test_encode_unknown_character_returns_zero_vector(self, encoder_default):
# the last character 'n' is not in 'acgt', so the last row should be all zeros
assert torch.all(encoded[-1] == 0)

def test_encode_case_sensitivity_true(self, encoder_case_sensitive):
def test_encode_default(self, encoder_default):
"""Default encoder (alphabet 'acgt', no lowercase conversion): uppercase 'A' and 'C' are not
in the alphabet and encode to all-zero rows, while 'g' and 't' are one-hot encoded.
"""
seq = "ACgt"
encoded = encoder_case_sensitive.encode(seq)
encoded = encoder_default.encode(seq)
# shape = (len(seq), 4)
assert encoded.shape == (4, 4)
# 'A' and 'C' are not in the default alphabet 'acgt', so their rows are all zeros; 'g' is one-hot at index 2 and 't' at index 3.
# The order of categories in OneHotEncoder is typically ['a', 'c', 'g', 't'] given we passed the alphabet 'acgt'.
assert torch.all(encoded[0] == torch.tensor([1, 0, 0, 0])) # 'A'
assert torch.all(encoded[1] == torch.tensor([0, 1, 0, 0])) # 'C'
assert torch.all(encoded[0] == torch.tensor([0, 0, 0, 0])) # 'A'
assert torch.all(encoded[1] == torch.tensor([0, 0, 0, 0])) # 'C'
assert torch.all(encoded[2] == torch.tensor([0, 0, 1, 0])) # 'g'
assert torch.all(encoded[3] == torch.tensor([0, 0, 0, 1])) # 't'

def test_encode_case_sensitivity_false(self, encoder_default):
"""Case-insensitive: 'ACGT' => 'acgt' internally."""
seq = "ACGT"
encoded = encoder_default.encode(seq)
def test_encode_lowercase(self, encoder_lowercase):
"""Case-insensitive: 'ACgt' => 'acgt' internally."""
seq = "ACgt"
encoded = encoder_lowercase.encode(seq)
# shape = (4,4)
assert encoded.shape == (4, 4)
# The order of categories in OneHotEncoder is typically ['a', 'c', 'g', 't'], since the alphabet 'ACgt' is lowercased when convert_lowercase=True.
Expand Down Expand Up @@ -127,7 +125,7 @@ def test_encode_all_with_list_of_sequences(self, encoder_default):
assert torch.all(encoded[1] == encoder_default.encode(seqs[1]))

def test_encode_all_with_padding_false(self):
encoder = TextOneHotEncoder(alphabet="acgt", case_sensitive=False, padding=False)
encoder = TextOneHotEncoder(alphabet="acgt", padding=False)
seqs = ["acgt", "acgtn"] # different lengths
# should raise ValueError because lengths differ
with pytest.raises(ValueError) as excinfo:
Expand Down

0 comments on commit 6ee20d9

Please sign in to comment.