Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tabpfn_client.constants import CACHE_DIR
from tabpfn_client.browser_auth import BrowserAuthHandler
from tabpfn_client.tabpfn_common_utils.utils import Singleton
from tabpfn_client.tabpfn_common_utils.usage_analytics import AnalyticsHttpClient

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -157,15 +158,20 @@ class ServiceClient(Singleton):
httpx_timeout_s = (
4 * 5 * 60 + 15 # temporary workaround for slow computation on server side
)
httpx_client = httpx.Client(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests missing.

httpx_client = AnalyticsHttpClient(
base_url=base_url,
timeout=httpx_timeout_s,
headers={"client-version": get_client_version()},
module_name="tabpfn_client",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure there is no way of getting the module name programmatically?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least not robustly, in my opinion.
My rationale: tabpfn-extensions or wrappers built by users can create arbitrarily deep call stack. So I thought it'd be easier to let the wrapper, e.g. tabpfn-extensions, set the module name explicitly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I was hoping there is a better way to get the module name than iterating the call stack.

)

_access_token = None
dataset_uid_cache_manager = DatasetUIDCacheManager()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be private btw.


@classmethod
def set_module_name(cls, module_name: str) -> None:
cls.httpx_client.set_module_name(module_name)

@classmethod
def get_access_token(cls):
return cls._access_token
Expand Down
7 changes: 6 additions & 1 deletion tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ def __new__(cls, *args, **kwargs):
use_server = False


def init(use_server=True):
def init(
use_server=True,
module_name="tabpfn_client",
):
# initialize config
Config.use_server = use_server

Expand All @@ -30,6 +33,8 @@ def init(use_server=True):
return

if use_server:
ServiceClient.set_module_name(module_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are not expecting this to change during service run. It should be ONLY in the constructor of ServiceClient and then stored as a private member.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ServiceClient is a singleton class, hence we don't have an constructor as per say.
I'm also wondering if there's a better design for classes like ServiceClient

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be. Singletons are to be handled with cautions, because as for global variables, they can be changed in multiple places and thus introduce side-effects. Also they are harder to test, because dependency injection does not work.

I'd need to think more about this here.

But one way would be to create a client instance. And then this instance needs to be passed when calls are made.


# check connection to server
if not UserAuthenticationClient.is_accessible_connection():
raise RuntimeError(
Expand Down
2 changes: 1 addition & 1 deletion tabpfn_client/tabpfn_common_utils