diff --git a/rtw/airports.py b/rtw/airports.py new file mode 100644 index 0000000..92026d6 --- /dev/null +++ b/rtw/airports.py @@ -0,0 +1,38 @@ +"""Shared airportsdata loader with fail-fast. + +Loads the airportsdata IATA database exactly once at import time. +All consumer modules should import from here instead of loading +airportsdata directly. + +If airportsdata is not installed, exits immediately with code 2 +and a clear error message — without it, distances become 0, +continents become None, and every downstream calculation is wrong. +""" + +import sys + +try: + import airportsdata + + airports_db: dict = airportsdata.load("IATA") +except ImportError: + try: + from rich.console import Console + from rich.panel import Panel + + console = Console(stderr=True) + console.print( + Panel( + "Required library 'airportsdata' is not available.\n\n" + "Fix: pip install airportsdata", + title="Missing Dependency", + border_style="red", + ) + ) + except ImportError: + print( + "Error: Required library 'airportsdata' is not available.\n" + "Fix: pip install airportsdata", + file=sys.stderr, + ) + sys.exit(2) diff --git a/rtw/cli.py b/rtw/cli.py index 02d96c8..6b12e3b 100644 --- a/rtw/cli.py +++ b/rtw/cli.py @@ -100,11 +100,10 @@ def _setup_logging(verbose: bool = False, quiet: bool = False) -> None: def _known_airport_codes() -> list[str]: """Return a list of known airport codes for fuzzy matching.""" try: - import airportsdata + from rtw.airports import airports_db - db = airportsdata.load("IATA") - return list(db.keys()) - except Exception: + return list(airports_db.keys()) + except SystemExit: return [] @@ -259,6 +258,11 @@ def cost( except typer.BadParameter: raise except Exception as exc: + from rtw.cost import FareLookupError + + if isinstance(exc, FareLookupError): + _error_panel(str(exc)) + raise typer.Exit(code=2) _error_panel(str(exc)) raise typer.Exit(code=2) diff --git a/rtw/continents.py b/rtw/continents.py index 70f3558..65b7794 100644 --- a/rtw/continents.py +++ b/rtw/continents.py @@ -5,13 +5,7 @@ import yaml -try: - import airportsdata - - _airports_db = airportsdata.load("IATA") -except Exception: - _airports_db = {} - +from rtw.airports import airports_db as _airports_db from rtw.models import Continent, TariffConference, CONTINENT_TO_TC _DATA_DIR = Path(__file__).parent / "data" diff --git a/rtw/cost.py b/rtw/cost.py index c802dae..a44d305 100644 --- a/rtw/cost.py +++ b/rtw/cost.py @@ -13,54 +13,20 @@ import yaml +from rtw.airports import airports_db from rtw.models import CostEstimate, Itinerary, TicketType _DATA_DIR = Path(__file__).parent / "data" -# Major US airport codes for AA domestic zero-YQ rule + +class FareLookupError(Exception): + """Raised when a fare lookup returns $0 (missing data).""" + + pass + +# US airport codes for AA domestic zero-YQ rule (dynamic from airportsdata) _US_AIRPORTS = { - "JFK", - "EWR", - "LGA", - "LAX", - "SFO", - "ORD", - "DFW", - "MIA", - "ATL", - "SEA", - "BOS", - "DEN", - "PHX", - "MCO", - "IAD", - "IAH", - "CLT", - "PHL", - "SAN", - "AUS", - "MSP", - "DTW", - "SLC", - "HNL", - "OGG", - "TPA", - "FLL", - "BWI", - "DCA", - "STL", - "PDX", - "BNA", - "RDU", - "CLE", - "PIT", - "IND", - "MCI", - "OAK", - "SJC", - "SMF", - "ABQ", - "ANC", + code for code, info in airports_db.items() if info.get("country") == "US" } @@ -177,6 +143,11 @@ def estimate_total(self, itinerary: Itinerary, plating_carrier: str = "AA") -> C passengers = itinerary.ticket.passengers base_fare = self.get_base_fare(origin, ticket_type) + if base_fare == 0.0: + raise FareLookupError( + f"No fare data for origin={origin} ticket_type={ticket_type.value}. " + f"Check rtw/data/fares.yaml." + ) total_yq = self.estimate_surcharges(itinerary, plating_carrier) per_person = base_fare + total_yq total_all = per_person * passengers diff --git a/rtw/data/carriers.yaml b/rtw/data/carriers.yaml index dff86f7..15e451d 100644 --- a/rtw/data/carriers.yaml +++ b/rtw/data/carriers.yaml @@ -141,15 +141,25 @@ UL: rtw_booking_class: D notes: "D class = 25% of distance for NTP." +WY: + name: Oman Air + alliance: oneworld + eligible: true + ntp_method: distance + yq_tier: low + yq_estimate_per_segment: 90 + rtw_booking_class: D + notes: "Joined oneworld as full member June 2025. Distance-based NTP. D class = 12.5% (lowest in alliance for business)." + S7: name: S7 Airlines alliance: oneworld - eligible: true + eligible: false ntp_method: distance yq_tier: low yq_estimate_per_segment: 40 rtw_booking_class: D - notes: "Russian carrier. Limited RTW utility due to sanctions." + notes: "Russian carrier. Suspended from oneworld Explorer due to EU/US/UK sanctions on Russian aviation. Technically still a oneworld member but flights cannot be ticketed on RTW itineraries." # Ineligible carriers (for reference) LA: diff --git a/rtw/data/fares.yaml b/rtw/data/fares.yaml index 5c60388..09b71cd 100644 --- a/rtw/data/fares.yaml +++ b/rtw/data/fares.yaml @@ -1,6 +1,11 @@ # oneworld Explorer base fares by origin city (approximate USD) # Based on FlyerTalk data and fare filings as of 2025-2026 # Note: Fares fluctuate with currency rates and filing changes +# +# AONE (first class) fares: estimated at 1.6x DONE multiplier, +# rounded to nearest $100. Actual filed fares vary by origin currency +# and are not publicly published per-origin. The 1.6x ratio is derived +# from ex-UK GBP filings where AONE/DONE ratios range 1.4-1.8x. origins: CAI: @@ -8,6 +13,10 @@ origins: currency: EGP notes: "Historically cheapest. EGP devaluation advantage." fares: + AONE3: 5600 + AONE4: 6400 + AONE5: 7000 + AONE6: 8800 DONE3: 3500 DONE4: 4000 DONE5: 4400 @@ -22,6 +31,10 @@ origins: currency: EUR notes: "Norway filing advantage. Easy positioning from London." fares: + AONE3: 7700 + AONE4: 8600 + AONE5: 9300 + AONE6: 10400 DONE3: 4800 DONE4: 5400 DONE5: 5800 @@ -36,6 +49,10 @@ origins: currency: ZAR notes: "ZAR weakness. Good if Africa is on route." fares: + AONE3: 6400 + AONE4: 8000 + AONE5: 9100 + AONE6: 10700 DONE3: 4000 DONE4: 5000 DONE5: 5700 @@ -50,6 +67,10 @@ origins: currency: JPY notes: "Higher taxes/YQ than CAI/OSL." fares: + AONE3: 8800 + AONE4: 10200 + AONE5: 11600 + AONE6: 13600 DONE3: 5500 DONE4: 6360 DONE5: 7260 @@ -64,6 +85,10 @@ origins: currency: LKR notes: "SriLankan Airlines is oneworld. Occasionally cheap." fares: + AONE3: 7200 + AONE4: 8300 + AONE5: 9600 + AONE6: 11200 DONE3: 4500 DONE4: 5200 DONE5: 6000 @@ -78,6 +103,10 @@ origins: currency: GBP notes: "High UK departure taxes on premium long-haul." fares: + AONE3: 11200 + AONE4: 12800 + AONE5: 14400 + AONE6: 16800 DONE3: 7000 DONE4: 8000 DONE5: 9000 @@ -92,6 +121,10 @@ origins: currency: USD notes: "Most expensive origin. Consider positioning to CAI/OSL." fares: + AONE3: 14400 + AONE4: 16800 + AONE5: 19200 + AONE6: 22600 DONE3: 9000 DONE4: 10500 DONE5: 12000 @@ -106,6 +139,10 @@ origins: currency: AUD notes: "Ex-Japan is roughly half the price for same itinerary." fares: + AONE3: 12000 + AONE4: 14100 + AONE5: 16000 + AONE6: 19200 DONE3: 7500 DONE4: 8800 DONE5: 10000 diff --git a/rtw/data/ntp_rates.yaml b/rtw/data/ntp_rates.yaml index 5c5e328..737ce52 100644 --- a/rtw/data/ntp_rates.yaml +++ b/rtw/data/ntp_rates.yaml @@ -219,6 +219,23 @@ distance_based: S: 2 V: 2 + WY: + D: 12.5 + J: 12.5 + C: 12.5 + I: 6 + Z: 6 + Y: 3.5 + B: 3.5 + H: 3.5 + K: 2 + L: 2 + M: 2 + N: 2 + Q: 2 + S: 2 + V: 2 + # BA bonus NTP (per segment, BA-marketed flights only) # Permanent from 25 Nov 2025 ba_bonus: diff --git a/rtw/distance.py b/rtw/distance.py index 804cb95..c8a626d 100644 --- a/rtw/distance.py +++ b/rtw/distance.py @@ -1,14 +1,9 @@ """Great-circle distance between airports using IATA codes.""" -try: - import airportsdata - - _airports_db = airportsdata.load("IATA") -except Exception: - _airports_db = {} - from haversine import haversine, Unit +from rtw.airports import airports_db as _airports_db + class DistanceCalculator: """Calculate great-circle distances between airports.""" diff --git a/rtw/search/query.py b/rtw/search/query.py index 38858b2..6830e31 100644 --- a/rtw/search/query.py +++ b/rtw/search/query.py @@ -6,16 +6,10 @@ from datetime import date from typing import Optional +from rtw.airports import airports_db as _airports_db from rtw.models import CabinClass, TicketType from rtw.search.models import SearchQuery -try: - import airportsdata - - _airports_db = airportsdata.load("IATA") -except Exception: - _airports_db = {} - def _fuzzy_suggestion(code: str) -> str: """Suggest close airport codes.""" diff --git a/tests/conftest.py b/tests/conftest.py index 60c3679..94ec80c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,3 +47,15 @@ def too_many_segments_itinerary(load_yaml): def minimal_valid_itinerary(load_yaml): """Load the minimal valid routing.""" return load_yaml("minimal_valid.yaml") + + +@pytest.fixture +def done3_itinerary(load_yaml): + """Load the DONE3 business 3-continent routing.""" + return load_yaml("done3_cai_eastbound.yaml") + + +@pytest.fixture +def lone3_itinerary(load_yaml): + """Load the LONE3 economy 3-continent routing.""" + return load_yaml("lone3_osl_westbound.yaml") diff --git a/tests/fixtures/done3_cai_eastbound.yaml b/tests/fixtures/done3_cai_eastbound.yaml new file mode 100644 index 0000000..9f93545 --- /dev/null +++ b/tests/fixtures/done3_cai_eastbound.yaml @@ -0,0 +1,38 @@ +# DONE3 - Business Class, 3 continents ex-Cairo (Eastbound) +# Covers: EU/ME, Asia, SWP — minimum continent count +# Tests: DONE3 fare, 3-continent routing, business class +ticket: + type: DONE3 + cabin: business + origin: CAI + passengers: 1 + +segments: + - from: CAI + to: AMM + carrier: RJ + type: stopover + - from: AMM + to: DOH + carrier: QR + type: transit + - from: DOH + to: NRT + carrier: QR + type: stopover + - from: NRT + to: SYD + carrier: JL + type: stopover + - from: SYD + to: SIN + carrier: QF + type: stopover + - from: SIN + to: DOH + carrier: QR + type: transit + - from: DOH + to: CAI + carrier: QR + type: final diff --git a/tests/fixtures/lone3_osl_westbound.yaml b/tests/fixtures/lone3_osl_westbound.yaml new file mode 100644 index 0000000..e90a266 --- /dev/null +++ b/tests/fixtures/lone3_osl_westbound.yaml @@ -0,0 +1,34 @@ +# LONE3 - Economy Class, 3 continents ex-Oslo (Westbound) +# Covers: EU/ME, N.America, Asia — budget economy routing +# Tests: LONE3 fare, economy cabin, westbound direction +ticket: + type: LONE3 + cabin: economy + origin: OSL + passengers: 1 + +segments: + - from: OSL + to: LHR + carrier: BA + type: transit + - from: LHR + to: JFK + carrier: BA + type: stopover + - from: JFK + to: LAX + carrier: AA + type: stopover + - from: LAX + to: NRT + carrier: JL + type: stopover + - from: NRT + to: HEL + carrier: AY + type: stopover + - from: HEL + to: OSL + carrier: AY + type: final diff --git a/tests/test_airports.py b/tests/test_airports.py new file mode 100644 index 0000000..6c36898 --- /dev/null +++ b/tests/test_airports.py @@ -0,0 +1,14 @@ +"""Tests for rtw.airports shared module.""" + +from rtw.airports import airports_db + + +class TestAirportsDB: + def test_airports_db_is_populated(self): + """Airport database should have 5000+ entries.""" + assert len(airports_db) > 5000 + + def test_airports_db_contains_common_codes(self): + """Database should contain major hub codes.""" + for code in ["JFK", "LHR", "NRT", "SYD", "CAI"]: + assert code in airports_db, f"{code} missing from airports_db" diff --git a/tests/test_cost.py b/tests/test_cost.py index 5b8837f..f9bd6a7 100644 --- a/tests/test_cost.py +++ b/tests/test_cost.py @@ -3,7 +3,7 @@ import pytest from rtw.models import CostEstimate, Itinerary, TicketType -from rtw.cost import CostEstimator +from rtw.cost import CostEstimator, FareLookupError @pytest.fixture @@ -11,13 +11,13 @@ def estimator(): return CostEstimator() -def _make_itinerary(segments, origin="CAI", ticket_type="DONE4", passengers=1): +def _make_itinerary(segments, origin="CAI", ticket_type="DONE4", passengers=1, cabin="business"): """Build a minimal Itinerary for testing.""" return Itinerary.model_validate( { "ticket": { "type": ticket_type, - "cabin": "business", + "cabin": cabin, "origin": origin, "passengers": passengers, }, @@ -348,3 +348,125 @@ def test_plating_affects_notes(self, estimator, v3_itinerary): itin = Itinerary.model_validate(v3_itinerary) result = estimator.estimate_total(itin, plating_carrier="AA") assert "flexibility" in result.notes.lower() or "change" in result.notes.lower() + + +# ------------------------------------------------------------------ +# AONE (first class) fare tests +# ------------------------------------------------------------------ + + +class TestAONEFares: + """AONE first class fares from fares.yaml.""" + + _ORIGINS = ["CAI", "OSL", "JNB", "NRT", "CMB", "LHR", "JFK", "SYD"] + _AONE_TYPES = [TicketType.AONE3, TicketType.AONE4, TicketType.AONE5, TicketType.AONE6] + + def test_aone_fares_nonzero(self, estimator): + """AONE4 fare should be > 0 for all 8 origins.""" + for origin in self._ORIGINS: + fare = estimator.get_base_fare(origin, TicketType.AONE4) + assert fare > 0, f"AONE4 fare for {origin} is $0" + + def test_aone_fares_all_origins_all_types(self, estimator): + """All 32 AONE entries should be > 0.""" + for origin in self._ORIGINS: + for tt in self._AONE_TYPES: + fare = estimator.get_base_fare(origin, tt) + assert fare > 0, f"{tt.value} fare for {origin} is $0" + + def test_aone_more_than_done(self, estimator): + """AONE should be more expensive than DONE for same origin+continent count.""" + done_types = [TicketType.DONE3, TicketType.DONE4, TicketType.DONE5, TicketType.DONE6] + for origin in self._ORIGINS: + for aone, done in zip(self._AONE_TYPES, done_types): + aone_fare = estimator.get_base_fare(origin, aone) + done_fare = estimator.get_base_fare(origin, done) + assert aone_fare > done_fare, ( + f"{origin}: {aone.value}=${aone_fare} should be > {done.value}=${done_fare}" + ) + + def test_aone_done_ratio_in_range(self, estimator): + """AONE/DONE ratio should be in 1.4-1.8x range.""" + done_types = [TicketType.DONE3, TicketType.DONE4, TicketType.DONE5, TicketType.DONE6] + for origin in self._ORIGINS: + for aone, done in zip(self._AONE_TYPES, done_types): + aone_fare = estimator.get_base_fare(origin, aone) + done_fare = estimator.get_base_fare(origin, done) + ratio = aone_fare / done_fare + assert 1.4 <= ratio <= 1.8, ( + f"{origin}: {aone.value}/{done.value} ratio = {ratio:.2f}, expected 1.4-1.8" + ) + + +# ------------------------------------------------------------------ +# Zero-fare guard tests +# ------------------------------------------------------------------ + + +class TestZeroFareGuard: + """FareLookupError when fare data is missing.""" + + def test_zero_fare_raises_fare_lookup_error(self, estimator): + """estimate_total() should raise FareLookupError when base_fare == $0.""" + itin = _make_itinerary( + [{"from": "ZZZ", "to": "YYY", "carrier": "QR"}], + origin="ZZZ", + ) + with pytest.raises(FareLookupError): + estimator.estimate_total(itin) + + def test_fare_lookup_error_message_contains_origin(self, estimator): + """FareLookupError message should contain the origin code and ticket type.""" + itin = _make_itinerary( + [{"from": "ZZZ", "to": "YYY", "carrier": "QR"}], + origin="ZZZ", + ) + with pytest.raises(FareLookupError, match="ZZZ"): + estimator.estimate_total(itin) + + +# ------------------------------------------------------------------ +# Fixture-based tests (DONE3, LONE3) +# ------------------------------------------------------------------ + + +class TestDONE3Fixture: + """DONE3 fixture: 3-continent business routing.""" + + def test_done3_base_fare_nonzero(self, estimator, done3_itinerary): + """DONE3 fixture should have nonzero base fare.""" + itin = Itinerary.model_validate(done3_itinerary) + result = estimator.estimate_total(itin) + assert result.base_fare_usd > 0 + + def test_done3_per_person_in_range(self, estimator, done3_itinerary): + """DONE3 per-person cost should be reasonable ($3,000-$6,000).""" + itin = Itinerary.model_validate(done3_itinerary) + result = estimator.estimate_total(itin) + assert 3_000 <= result.total_per_person_usd <= 6_000 + + def test_done3_cheaper_than_done4(self, estimator, done3_itinerary, v3_itinerary): + """DONE3 base fare should be less than DONE4.""" + done3 = Itinerary.model_validate(done3_itinerary) + done4 = Itinerary.model_validate(v3_itinerary) + # Both are ex-CAI, so compare base fares + fare3 = estimator.get_base_fare(done3.ticket.origin, done3.ticket.type) + fare4 = estimator.get_base_fare(done4.ticket.origin, done4.ticket.type) + assert fare3 < fare4 + + +class TestLONE3Fixture: + """LONE3 fixture: 3-continent economy routing.""" + + def test_lone3_base_fare_nonzero(self, estimator, lone3_itinerary): + """LONE3 fixture should have nonzero base fare.""" + itin = Itinerary.model_validate(lone3_itinerary) + result = estimator.estimate_total(itin) + assert result.base_fare_usd > 0 + + def test_lone3_cheaper_than_done3(self, estimator, lone3_itinerary): + """LONE3 should be cheaper than DONE3 for same continent count.""" + itin = Itinerary.model_validate(lone3_itinerary) + lone3_fare = estimator.get_base_fare(itin.ticket.origin, TicketType.LONE3) + done3_fare = estimator.get_base_fare(itin.ticket.origin, TicketType.DONE3) + assert lone3_fare < done3_fare diff --git a/tests/test_fare_comparison.py b/tests/test_fare_comparison.py index a8e1fbc..f22e148 100644 --- a/tests/test_fare_comparison.py +++ b/tests/test_fare_comparison.py @@ -101,11 +101,11 @@ def test_fare_lookup_unknown_origin(self): e = CostEstimator() assert e.get_base_fare("BOM", TicketType.DONE4) == 0.0 - def test_fare_lookup_unknown_ticket_type(self): + def test_fare_lookup_aone4_exists(self): e = CostEstimator() - # AONE4 doesn't exist in our data + # AONE4 now exists in fares.yaml result = e.get_base_fare("SYD", TicketType("AONE4")) - assert result == 0.0 + assert result > 0 def test_fare_lookup_case_insensitive(self): e = CostEstimator() diff --git a/tests/test_ntp.py b/tests/test_ntp.py index c0b69bd..4297c3a 100644 --- a/tests/test_ntp.py +++ b/tests/test_ntp.py @@ -96,6 +96,20 @@ def test_fj_atr72_earns_at_d_rate(self, calc): assert est.estimated_ntp == pytest.approx(162, abs=25) assert "ATR-72" in est.notes + def test_wy_ntp_positive_for_business(self, calc): + """WY MCT-LHR in D class should earn positive NTP (12.5% of ~3260 miles).""" + itin = _make_itinerary( + [ + {"from": "MCT", "to": "LHR", "carrier": "WY"}, + ] + ) + results = calc.calculate(itin, booking_class="D") + assert len(results) == 1 + est = results[0] + assert est.method == NTPMethod.DISTANCE + assert est.rate == 12.5 + assert est.estimated_ntp > 0 + # ------------------------------------------------------------------ # Revenue-based NTP tests diff --git a/tests/test_rules/test_carriers.py b/tests/test_rules/test_carriers.py index d6c268a..ac3ecea 100644 --- a/tests/test_rules/test_carriers.py +++ b/tests/test_rules/test_carriers.py @@ -76,6 +76,67 @@ def test_v3_matches_done4(self, v3_itinerary): assert len(passed_results) >= 0 # May pass or warn depending on continent resolution +class TestWYCarrier: + """Oman Air (WY) carrier data and eligibility.""" + + def test_wy_is_eligible(self): + """WY should be eligible for oneworld Explorer.""" + segs = [ + {"from": "MCT", "to": "LHR", "carrier": "WY"}, + {"from": "LHR", "to": "MCT", "carrier": "WY"}, + ] + itin = _make_itinerary(segs, origin="MCT") + ctx = build_context(itin) + results = EligibleCarrierRule().check(itin, ctx) + assert all(r.passed for r in results), "WY should be eligible" + + def test_wy_carrier_data_complete(self): + """WY should have all required carrier fields.""" + import yaml + from pathlib import Path + + carriers_path = Path(__file__).parent.parent.parent / "rtw" / "data" / "carriers.yaml" + with open(carriers_path) as f: + carriers = yaml.safe_load(f) + wy = carriers.get("WY", {}) + assert wy.get("name") == "Oman Air" + assert wy.get("alliance") == "oneworld" + assert wy.get("eligible") is True + assert wy.get("ntp_method") == "distance" + assert wy.get("rtw_booking_class") == "D" + assert wy.get("yq_tier") is not None + assert wy.get("yq_estimate_per_segment") is not None + + +class TestS7Carrier: + """S7 Airlines sanctions flag.""" + + def test_s7_is_ineligible(self): + """S7 should be ineligible (sanctions-suspended).""" + segs = [ + {"from": "OVB", "to": "SVO", "carrier": "S7"}, + {"from": "SVO", "to": "OVB", "carrier": "S7"}, + ] + itin = _make_itinerary(segs, origin="OVB") + ctx = build_context(itin) + results = EligibleCarrierRule().check(itin, ctx) + assert any(not r.passed for r in results), "S7 should fail eligibility" + + def test_s7_violation_mentions_sanctions(self): + """S7 violation message should mention sanctions.""" + segs = [ + {"from": "OVB", "to": "SVO", "carrier": "S7"}, + {"from": "SVO", "to": "OVB", "carrier": "S7"}, + ] + itin = _make_itinerary(segs, origin="OVB") + ctx = build_context(itin) + results = EligibleCarrierRule().check(itin, ctx) + failed = [r for r in results if not r.passed] + assert any("sanction" in r.message.lower() for r in failed), ( + "S7 violation should mention sanctions" + ) + + class TestTicketValidity: def test_v3_valid_duration(self, v3_itinerary): itin = Itinerary(**v3_itinerary)