Skip to content

Commit e6b4832

Browse files
authored
only close session if owner (#13)
1 parent 8542332 commit e6b4832

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

mystbin/client.py

+13
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030

3131
if TYPE_CHECKING:
3232
import datetime
33+
from types import TracebackType
3334

3435
from aiohttp import ClientSession
36+
from typing_extensions import Self
3537

3638
__all__ = ("Client",)
3739

@@ -42,6 +44,17 @@ class Client:
4244
def __init__(self, *, token: str | None = None, session: ClientSession | None = None) -> None:
4345
self.http: HTTPClient = HTTPClient(token=token, session=session)
4446

47+
async def __aenter__(self) -> Self:
48+
return self
49+
50+
async def __aexit__(
51+
self,
52+
exc_cls: type[BaseException] | None,
53+
exc_value: BaseException | None,
54+
traceback: TracebackType | None
55+
) -> None:
56+
await self.close()
57+
4558
async def close(self) -> None:
4659
"""|coro|
4760

mystbin/http.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def __init__(self, verb: SupportedHTTPVerb, path: str, **params: Any) -> None:
126126
class HTTPClient:
127127
__slots__ = (
128128
"_session",
129+
"_owns_session",
129130
"_async",
130131
"_token",
131132
"_locks",
@@ -135,16 +136,18 @@ class HTTPClient:
135136
def __init__(self, *, token: str | None, session: aiohttp.ClientSession | None = None) -> None:
136137
self._token: str | None = token
137138
self._session: aiohttp.ClientSession | None = session
139+
self._owns_session: bool = False
138140
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
139141
user_agent = "mystbin.py (https://github.com/PythonistaGuild/mystbin.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
140142
self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
141143

142144
async def close(self) -> None:
143-
if self._session:
145+
if self._session and self._owns_session:
144146
await self._session.close()
145147

146148
async def _generate_session(self) -> aiohttp.ClientSession:
147149
self._session = aiohttp.ClientSession()
150+
self._owns_session = True
148151
return self._session
149152

150153
async def request(self, route: Route, **kwargs: Any) -> Any:

0 commit comments

Comments
 (0)