Skip to content

Commit

Permalink
Add Generator design pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
everysoftware committed Nov 28, 2024
1 parent f35a4cf commit 4d71c84
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 17 deletions.
142 changes: 142 additions & 0 deletions src/behavioral/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
Conceptually, an iterator is a mechanism for traversing data element by element, while a generator allows you to
lazily create a result during iteration.
In Python, any function that uses the yield keyword is a generator function. When called, it returns a generator object
that can be used to control the execution of the generator function. The generator object is both an iterator and an
iterable, so you can use it in for loops and pass it to any function that expects an iterable.
"""

from __future__ import annotations

import io
from abc import ABC, abstractmethod
from sqlite3 import Connection
from types import TracebackType
from typing import Any, Self, Generator as PythonGenerator

from src.behavioral.iterator import Iterator


class Generator[YieldT, SendT, ReturnT](Iterator[YieldT], ABC):
@abstractmethod
def send(self, value: SendT) -> YieldT: ...

@abstractmethod
def throw(
self,
exc_type: type[BaseException],
exc_val: BaseException | None = None,
tb: TracebackType | None = None,
) -> Self: ...

@abstractmethod
def close(self) -> None: ...


# Basic usage


# Type-hinting equivalent to Iterator[int]
def gen_pow(n: int) -> PythonGenerator[int, None, None]:
yield n**0
yield n**1
yield n**2
yield n**3


# Delegate to another generator
def gen_pow_delegation(n: int) -> PythonGenerator[int, None, None]:
yield from gen_pow(n)


def count(n: int) -> PythonGenerator[int, None, None]:
for i in range(n):
yield i


def count_delegation(n: int) -> PythonGenerator[int, None, None]:
yield from range(n)


def gen_sum() -> PythonGenerator[int, int, int]:
total = 0
while True:
try:
value = yield total
if value is not None:
total += value
except StopIteration:
return total


# Class-based generator


class SumGenerator(Generator[int, int, int]):
def __init__(self) -> None:
self._total = 0
self._closed = False

def send(self, value: int) -> int:
if self._closed:
raise StopIteration(self._total)
self._total += value
return self._total

def throw(
self,
exc_type: type[BaseException],
exc_val: BaseException | None = None,
tb: TracebackType | None = None,
) -> SumGenerator:
return self

def close(self) -> None:
self._closed = True

def __iter__(self) -> SumGenerator:
return self

def __next__(self) -> int:
return self.send(0)


def gen_line(
output: io.StringIO, state: dict[str, Any]
) -> PythonGenerator[str, None, None]:
# lines
try:
while True:
line = output.readline().rstrip()
if not line:
break
yield line
finally:
state["closed"] = True
output.close()


class CommitException(Exception):
pass


class AbortException(Exception):
pass


def db_session(
db: Connection, sql: str
) -> PythonGenerator[None, tuple[Any, ...], None]:
cursor = db.cursor()
try:
while True:
try:
row = yield
cursor.execute(sql, row)
except CommitException:
db.commit()
except AbortException:
db.rollback()
finally:
db.rollback()
29 changes: 12 additions & 17 deletions src/behavioral/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,39 @@
without exposing its underlying representation.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Sequence, Self
from typing import Sequence, Self


class Iterator(ABC):
@abstractmethod
def __next__(self) -> Any: ...
class Iterable[T](ABC):
@abstractmethod
def has_next(self) -> bool: ...
def __iter__(self) -> Iterator[T]: ...


class Iterator[T](Iterable[T], ABC):
@abstractmethod
def __iter__(self) -> Self: ...
def __next__(self) -> T: ...


class NameIterator(Iterator):
class NameIterator(Iterator[str]):
def __init__(self, names: Sequence[str]) -> None:
self._names = names
self._position = 0

def __next__(self) -> str:
if not self.has_next():
if not self._position < len(self._names):
raise StopIteration
name = self._names[self._position]
self._position += 1
return name

def has_next(self) -> bool:
return self._position < len(self._names)

def __iter__(self) -> Self:
return self


class Iterable(ABC):
@abstractmethod
def __iter__(self) -> Iterator: ...


class NameCollection(Iterable):
class NameCollection(Iterable[str]):
def __init__(self) -> None:
self._names: list[str] = []

Expand Down
Empty file removed src/behavioral/pub_sub.py
Empty file.
92 changes: 92 additions & 0 deletions tests/test_behavioral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
import io
import sqlite3
from typing import Generator as PythonGenerator

import pytest

from src.behavioral.command import Bank, Account
from src.behavioral.cor import LeaveRequest, Manager, Director, TeamLead
from src.behavioral.generator import (
gen_pow,
gen_pow_delegation,
count,
count_delegation,
gen_sum,
SumGenerator,
gen_line,
db_session,
CommitException,
)
from src.behavioral.interpreter import Subtract, Number, Add
from src.behavioral.iterator import NameCollection
from src.behavioral.mediator import ChatRoom, Participant
Expand Down Expand Up @@ -258,3 +275,78 @@ def test_observer() -> None:
sensor.set_temperature(18)
assert ac.temperature == 18
assert heater.temperature == 18


def test_gen_pow() -> None:
g = gen_pow(2)
assert next(g) == 1
assert next(g) == 2
assert next(g) == 4
assert next(g) == 8
with pytest.raises(StopIteration):
next(g)

assert list(gen_pow(2)) == [1, 2, 4, 8]
assert list(gen_pow(2)) == list(gen_pow_delegation(2))
assert list(count(3)) == [0, 1, 2]
assert list(count(3)) == list(count_delegation(3))


@pytest.mark.parametrize(
"impl",
[
gen_sum,
SumGenerator,
],
)
def test_gen_sum(impl: type[PythonGenerator[int, int, int]]) -> None:
g = impl()

assert next(g) == 0
# SendT == int, YieldT == int
assert g.send(1) == 1
assert g.send(2) == 3
assert g.send(3) == 6
assert g.send(4) == 10
assert g.send(5) == 15
g.close()
with pytest.raises(StopIteration) as e:
next(g)
# ReturnT == int
assert e.value.value == 15


def test_gen_line() -> None:
# A file-like object
output = io.StringIO()
output.write("First line\n")
output.write("Second line\n")
output.write("Third line\n")
output.seek(0)
state = {"closed": False}

g = gen_line(output, state)
assert next(g) == "First line"
assert next(g) == "Second line"
assert next(g) == "Third line"
g.close()


def test_db_session() -> None:
db_url = ":memory:"
conn = sqlite3.connect(db_url)
with conn:
conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)")

session = db_session(conn, "INSERT INTO test VALUES (?)")
next(session)
session.send((12,))
session.send((42,))
session.send((96,))
session.throw(CommitException)

with conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM test")
rows = cursor.fetchall()
assert rows == [(12,), (42,), (96,)]

0 comments on commit 4d71c84

Please sign in to comment.