|
1 | 1 | import os
|
2 |
| -from typing import Any, Callable, Literal, Optional |
| 2 | +from typing import Any, Callable, Literal, Optional, Union |
3 | 3 |
|
4 | 4 | from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams
|
5 |
| -from pydantic import BaseModel, ConfigDict, Field |
| 5 | +from pydantic import BaseModel, ConfigDict, Field, field_validator |
6 | 6 |
|
7 | 7 | from stagehand.schemas import AvailableModel
|
8 | 8 |
|
@@ -71,7 +71,9 @@ class StagehandConfig(BaseModel):
|
71 | 71 | alias="domSettleTimeoutMs",
|
72 | 72 | description="Timeout for DOM to settle (in ms)",
|
73 | 73 | )
|
74 |
| - browserbase_session_create_params: Optional[BrowserbaseSessionCreateParams] = Field( |
| 74 | + browserbase_session_create_params: Optional[ |
| 75 | + Union[BrowserbaseSessionCreateParams, dict[str, Any]] |
| 76 | + ] = Field( |
75 | 77 | None,
|
76 | 78 | alias="browserbaseSessionCreateParams",
|
77 | 79 | description="Browserbase session create params",
|
@@ -118,6 +120,17 @@ class StagehandConfig(BaseModel):
|
118 | 120 |
|
119 | 121 | model_config = ConfigDict(populate_by_name=True)
|
120 | 122 |
|
| 123 | + @field_validator("browserbase_session_create_params", mode="before") |
| 124 | + @classmethod |
| 125 | + def validate_browserbase_params(cls, v, info): |
| 126 | + """Validate and convert browserbase session create params.""" |
| 127 | + if isinstance(v, dict) and "project_id" not in v: |
| 128 | + values = info.data |
| 129 | + project_id = values.get("project_id") or values.get("projectId") |
| 130 | + if project_id: |
| 131 | + v = {**v, "project_id": project_id} |
| 132 | + return v |
| 133 | + |
121 | 134 | def with_overrides(self, **overrides) -> "StagehandConfig":
|
122 | 135 | """
|
123 | 136 | Create a new config instance with the specified overrides.
|
|
0 commit comments