Skip to content

Commit 7e656ce

Browse files
committed
Fix per-carrier booking class in verify/scrape commands (closes #2)
AA uses H class for oneworld Explorer business, not D. The verify and scrape commands hardcoded D for all carriers, producing false negatives on AA segments. Now resolves booking class per carrier from carriers.yaml (AA=H, others=D) with --class flag as optional override. - New rtw/carriers.py shared utility with get_booking_class() - DClassVerifier resolves per-segment, updated cache keys - DClassResult.display_code uses actual class letter (H9/D9) - CLI --class default changed to auto per-carrier lookup - 37 new tests (822 total, 0 failures) Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent a3e874a commit 7e656ce

File tree

9 files changed

+349
-37
lines changed

9 files changed

+349
-37
lines changed

rtw/booking.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,28 +68,15 @@ def __init__(self) -> None:
6868
def _get_booking_class(self, carrier: Optional[str], cabin: CabinClass) -> Optional[str]:
6969
"""Determine the booking class for a carrier/cabin combination.
7070
71-
Business: AA -> H (special case), others -> D (from carriers.yaml).
72-
Economy: L for most carriers.
73-
First: A for most carriers.
74-
Surface segments (carrier=None): returns None.
71+
Delegates to shared utility in rtw.carriers. Returns None for
72+
surface segments (carrier=None).
7573
"""
7674
if carrier is None:
7775
return None
7876

79-
carrier = carrier.upper()
77+
from rtw.carriers import get_booking_class
8078

81-
if cabin == CabinClass.BUSINESS:
82-
# Use rtw_booking_class from carriers.yaml (AA=H, others=D)
83-
carrier_data = self._carriers.get(carrier, {})
84-
return carrier_data.get("rtw_booking_class", "D")
85-
86-
if cabin == CabinClass.ECONOMY:
87-
return "L"
88-
89-
if cabin == CabinClass.FIRST:
90-
return "A"
91-
92-
return "D"
79+
return get_booking_class(carrier, cabin)
9380

9481
def _is_same_city(self, airport1: str, airport2: str) -> bool:
9582
"""Check if two airports are in the same city group."""

rtw/carriers.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Shared carrier booking class resolution.
2+
3+
Resolves the correct booking class for a carrier/cabin combination
4+
using data from carriers.yaml. AA uses H class for oneworld Explorer
5+
business; all other carriers use D.
6+
"""
7+
8+
from pathlib import Path
9+
from typing import Optional
10+
11+
import yaml
12+
13+
from rtw.models import CabinClass
14+
15+
_DATA_DIR = Path(__file__).parent / "data"
16+
with open(_DATA_DIR / "carriers.yaml") as f:
17+
_CARRIERS: dict = yaml.safe_load(f)
18+
19+
20+
def get_booking_class(carrier: Optional[str], cabin: CabinClass) -> str:
21+
"""Return the booking class for a carrier/cabin combination.
22+
23+
Business: AA -> H (from carriers.yaml rtw_booking_class), others -> D.
24+
Economy: L for all carriers.
25+
First: A for all carriers.
26+
Surface segments (carrier=None): returns D as safe default.
27+
28+
Always returns a concrete single-letter string, never None.
29+
"""
30+
if carrier is None:
31+
return "D"
32+
33+
carrier = carrier.upper()
34+
35+
if cabin == CabinClass.BUSINESS:
36+
carrier_data = _CARRIERS.get(carrier, {})
37+
return carrier_data.get("rtw_booking_class", "D")
38+
39+
if cabin == CabinClass.ECONOMY:
40+
return "L"
41+
42+
if cabin == CabinClass.FIRST:
43+
return "A"
44+
45+
return "D"

rtw/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def scrape_prices(
648648
@scrape_app.command(name="availability")
649649
def scrape_availability(
650650
file: str = typer.Argument(help="Path to itinerary YAML file"),
651-
booking_class: str = typer.Option("D", "--class", "-c", help="Booking class to check"),
651+
booking_class: Optional[str] = typer.Option(None, "--class", "-c", help="Override booking class (default: auto per carrier, AA=H, others=D)"),
652652
verbose: VerboseFlag = False,
653653
quiet: QuietFlag = False,
654654
) -> None:
@@ -1059,7 +1059,7 @@ def verify(
10591059
option_ids: Annotated[
10601060
Optional[list[int]], typer.Argument(help="Option IDs to verify (1-based). Omit for top 3.")
10611061
] = None,
1062-
booking_class: Annotated[str, typer.Option("--class", "-c", help="Booking class")] = "D",
1062+
booking_class: Annotated[Optional[str], typer.Option("--class", "-c", help="Override booking class (default: auto per carrier, AA=H, others=D)")] = None,
10631063
no_cache: Annotated[bool, typer.Option("--no-cache", help="Skip cache")] = False,
10641064
json: JsonFlag = False,
10651065
plain: PlainFlag = False,
@@ -1128,6 +1128,7 @@ def verify(
11281128
cache=ScrapeCache(),
11291129
booking_class=booking_class,
11301130
)
1131+
# Note: booking_class=None means auto per-carrier (AA=H, others=D)
11311132

11321133
# Convert and verify with Rich progress
11331134
results = []

rtw/scraper/batch.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,38 +130,45 @@ def _try_playwright_price(origin, dest, seg_date, cabin):
130130

131131
def check_itinerary_availability(
132132
itinerary: Itinerary,
133-
booking_class: str = "D",
133+
booking_class: Optional[str] = None,
134134
) -> list[Optional[dict]]:
135135
"""Check award availability for all flown segments.
136136
137137
Args:
138138
itinerary: The RTW itinerary to check.
139-
booking_class: Booking class to check (default "D" for business award).
139+
booking_class: Override booking class for all segments. When None
140+
(default), resolves per carrier from carriers.yaml (AA=H, others=D).
140141
141142
Returns:
142143
List of availability dicts (or None) for each segment.
143144
Surface segments and segments without carriers get None. Never raises.
144145
"""
146+
from rtw.carriers import get_booking_class
147+
145148
scraper = ExpertFlyerScraper()
146149

147150
if not scraper.credentials_available():
148151
logger.info("ExpertFlyer credentials not available - returning empty results")
149152
return [None] * len(itinerary.segments)
150153

154+
cabin = itinerary.ticket.cabin
155+
151156
results: list[Optional[dict]] = []
152157

153158
for seg in itinerary.segments:
154159
if seg.is_surface or seg.carrier is None or seg.date is None:
155160
results.append(None)
156161
continue
157162

163+
seg_bc = booking_class if booking_class is not None else get_booking_class(seg.carrier, cabin)
164+
158165
try:
159166
avail = scraper.check_availability(
160167
origin=seg.from_airport,
161168
dest=seg.to_airport,
162169
date=seg.date,
163170
carrier=seg.carrier,
164-
booking_class=booking_class,
171+
booking_class=seg_bc,
165172
)
166173
results.append(avail)
167174
except Exception as exc:

rtw/verify/models.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class FlightAvailability(BaseModel):
4141

4242

4343
class DClassResult(BaseModel):
44-
"""Result of a D-class check for a single flight segment."""
44+
"""Result of an availability check for a single flight segment."""
4545

4646
status: DClassStatus
4747
seats: int = Field(default=0, ge=0, le=9)
@@ -50,6 +50,7 @@ class DClassResult(BaseModel):
5050
origin: str = Field(min_length=3, max_length=3)
5151
destination: str = Field(min_length=3, max_length=3)
5252
target_date: datetime.date
53+
booking_class: str = "D"
5354
checked_at: datetime.datetime = Field(
5455
default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
5556
)
@@ -64,7 +65,7 @@ def available(self) -> bool:
6465

6566
@property
6667
def available_flights(self) -> list[FlightAvailability]:
67-
"""Flights with D-class seats > 0, sorted by seats desc then departure."""
68+
"""Flights with seats > 0, sorted by seats desc then departure."""
6869
avail = [f for f in self.flights if f.seats > 0]
6970
return sorted(avail, key=lambda f: (-f.seats, f.depart_time or ""))
7071

@@ -78,14 +79,15 @@ def available_count(self) -> int:
7879

7980
@property
8081
def display_code(self) -> str:
81-
"""Short display code: D9 (3 avl), D0, D?, D!"""
82+
"""Short display code: H9 (3 avl), D0, D?, H!"""
83+
bc = self.booking_class
8284
if self.status == DClassStatus.ERROR:
83-
return "D!"
85+
return f"{bc}!"
8486
if self.status == DClassStatus.UNKNOWN:
85-
return "D?"
87+
return f"{bc}?"
8688
if self.flights:
87-
return f"D{self.seats} ({self.available_count} avl)"
88-
return f"D{self.seats}"
89+
return f"{bc}{self.seats} ({self.available_count} avl)"
90+
return f"{bc}{self.seats}"
8991

9092
@property
9193
def best_alternate(self) -> Optional[AlternateDateResult]:

rtw/verify/verifier.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
"""D-class verification orchestrator.
1+
"""Availability verification orchestrator.
22
33
Coordinates the scraper, cache, and progress reporting to verify
4-
D-class availability across all flown segments of an itinerary option.
4+
award class availability across all flown segments of an itinerary option.
5+
Uses per-carrier booking class resolution (AA=H, others=D for business).
56
"""
67

78
import logging
89
import time
910
from typing import Optional
1011

12+
from rtw.carriers import get_booking_class
13+
from rtw.models import CabinClass
1114
from rtw.scraper.cache import ScrapeCache
1215
from rtw.scraper.expertflyer import ExpertFlyerScraper, SessionExpiredError
1316
from rtw.verify.models import (
@@ -26,28 +29,44 @@
2629

2730

2831
class DClassVerifier:
29-
"""Verify D-class availability for itinerary segments.
32+
"""Verify award class availability for itinerary segments.
3033
3134
Checks each flown segment against ExpertFlyer, using the cache
3235
to avoid redundant queries. Surface segments are skipped.
36+
37+
Resolves booking class per carrier (AA=H, others=D for business)
38+
unless an explicit override is provided.
3339
"""
3440

3541
def __init__(
3642
self,
3743
scraper: ExpertFlyerScraper,
3844
cache: Optional[ScrapeCache] = None,
39-
booking_class: str = "D",
45+
booking_class: Optional[str] = None,
46+
cabin: CabinClass = CabinClass.BUSINESS,
4047
) -> None:
4148
self.scraper = scraper
4249
self.cache = cache or ScrapeCache()
43-
self.booking_class = booking_class
50+
self._booking_class_override = booking_class
51+
self.cabin = cabin
4452
self._session_expired = False
4553

54+
def _get_segment_booking_class(self, seg: SegmentVerification) -> str:
55+
"""Resolve the booking class for a segment.
56+
57+
If an override was set, use it for all segments.
58+
Otherwise, look up per carrier from carriers.yaml.
59+
"""
60+
if self._booking_class_override is not None:
61+
return self._booking_class_override
62+
return get_booking_class(seg.carrier, self.cabin)
63+
4664
def _cache_key(self, seg: SegmentVerification) -> str:
4765
"""Build cache key for a segment."""
66+
bc = self._get_segment_booking_class(seg)
4867
return (
4968
f"{_CACHE_KEY_PREFIX}_{seg.carrier}_{seg.origin}_"
50-
f"{seg.destination}_{seg.target_date}_{self.booking_class}"
69+
f"{seg.destination}_{seg.target_date}_{bc}"
5170
)
5271

5372
def _check_cache(self, seg: SegmentVerification) -> Optional[DClassResult]:
@@ -111,6 +130,7 @@ def verify_option(
111130
origin=seg.origin,
112131
destination=seg.destination,
113132
target_date=seg.target_date,
133+
booking_class=self._get_segment_booking_class(seg),
114134
error_message="Session expired during batch",
115135
)
116136
result.segments.append(verified)
@@ -132,13 +152,14 @@ def verify_option(
132152

133153
# Call scraper
134154
try:
155+
seg_bc = self._get_segment_booking_class(seg)
135156
start = time.time()
136157
dclass = self.scraper.check_availability(
137158
origin=seg.origin,
138159
dest=seg.destination,
139160
date=seg.target_date,
140161
carrier=seg.carrier or "",
141-
booking_class=self.booking_class,
162+
booking_class=seg_bc,
142163
)
143164
elapsed = time.time() - start
144165
logger.debug(
@@ -150,6 +171,7 @@ def verify_option(
150171
)
151172

152173
if dclass:
174+
dclass.booking_class = seg_bc
153175
verified.dclass = dclass
154176
self._store_cache(seg, dclass)
155177
else:
@@ -160,6 +182,7 @@ def verify_option(
160182
origin=seg.origin,
161183
destination=seg.destination,
162184
target_date=seg.target_date,
185+
booking_class=seg_bc,
163186
error_message="Scraper returned None (no session?)",
164187
)
165188

@@ -172,6 +195,7 @@ def verify_option(
172195
origin=seg.origin,
173196
destination=seg.destination,
174197
target_date=seg.target_date,
198+
booking_class=self._get_segment_booking_class(seg),
175199
error_message=str(exc),
176200
)
177201
except Exception as exc:
@@ -182,6 +206,7 @@ def verify_option(
182206
origin=seg.origin,
183207
destination=seg.destination,
184208
target_date=seg.target_date,
209+
booking_class=self._get_segment_booking_class(seg),
185210
error_message=str(exc),
186211
)
187212

tests/test_carriers.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Tests for shared carrier booking class resolution."""
2+
3+
import pytest
4+
5+
from rtw.carriers import get_booking_class
6+
from rtw.models import CabinClass
7+
8+
9+
class TestGetBookingClass:
10+
"""Test get_booking_class() for all carrier/cabin combinations."""
11+
12+
def test_aa_business_returns_h(self):
13+
assert get_booking_class("AA", CabinClass.BUSINESS) == "H"
14+
15+
def test_aa_lowercase_returns_h(self):
16+
assert get_booking_class("aa", CabinClass.BUSINESS) == "H"
17+
18+
def test_ba_business_returns_d(self):
19+
assert get_booking_class("BA", CabinClass.BUSINESS) == "D"
20+
21+
def test_cx_business_returns_d(self):
22+
assert get_booking_class("CX", CabinClass.BUSINESS) == "D"
23+
24+
def test_qr_business_returns_d(self):
25+
assert get_booking_class("QR", CabinClass.BUSINESS) == "D"
26+
27+
def test_jl_business_returns_d(self):
28+
assert get_booking_class("JL", CabinClass.BUSINESS) == "D"
29+
30+
def test_qf_business_returns_d(self):
31+
assert get_booking_class("QF", CabinClass.BUSINESS) == "D"
32+
33+
def test_unknown_carrier_business_returns_d(self):
34+
assert get_booking_class("ZZ", CabinClass.BUSINESS) == "D"
35+
36+
def test_economy_returns_l(self):
37+
assert get_booking_class("AA", CabinClass.ECONOMY) == "L"
38+
39+
def test_economy_any_carrier_returns_l(self):
40+
assert get_booking_class("BA", CabinClass.ECONOMY) == "L"
41+
42+
def test_first_returns_a(self):
43+
assert get_booking_class("AA", CabinClass.FIRST) == "A"
44+
45+
def test_first_any_carrier_returns_a(self):
46+
assert get_booking_class("QF", CabinClass.FIRST) == "A"
47+
48+
def test_none_carrier_returns_d(self):
49+
"""Surface segments (carrier=None) return safe default."""
50+
assert get_booking_class(None, CabinClass.BUSINESS) == "D"
51+
52+
def test_never_returns_none(self):
53+
"""Function always returns a string, never None."""
54+
for carrier in [None, "AA", "BA", "ZZ"]:
55+
for cabin in CabinClass:
56+
result = get_booking_class(carrier, cabin)
57+
assert isinstance(result, str)
58+
assert len(result) == 1

0 commit comments

Comments
 (0)