Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for multiple accounts #104

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added __init__.py
Empty file.
23 changes: 14 additions & 9 deletions src/pychatgpt/classes/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
colorama.init(autoreset=True)


def token_expired() -> bool:
def token_expired(email: str) -> bool:
"""
Check if the creds have expired
returns:
Expand All @@ -41,7 +41,7 @@ def token_expired() -> bool:

with open(path, 'r') as f:
creds = json.load(f)
expires_at = float(creds['expires_at'])
expires_at = float(creds[email]['expires_at'])
if time.time() > expires_at + 3600:
return True
else:
Expand All @@ -52,7 +52,7 @@ def token_expired() -> bool:
return True


def get_access_token() -> Tuple[str or None, str or None]:
def get_access_token(email: str) -> Tuple[str or None, str or None]:
"""
Get the access token
returns:
Expand All @@ -65,7 +65,7 @@ def get_access_token() -> Tuple[str or None, str or None]:

with open(path, 'r') as f:
creds = json.load(f)
return creds['access_token'], creds['expires_at']
return creds[email]['access_token'], creds[email]['expires_at']
except FileNotFoundError:
return None, None

Expand Down Expand Up @@ -358,7 +358,7 @@ def _part_eight(self, old_state: str, new_state):
access_token = access_token.split('"')[0]
print(f"{Fore.GREEN}[OpenAI][8] {Fore.WHITE}Access Token: {Fore.GREEN}{access_token}")
# Save access_token
self.save_access_token(access_token=access_token)
self.save_access_token(email=self.email_address, access_token=access_token)
else:
print(f"{Fore.GREEN}[OpenAI][8][CRITICAL] {Fore.WHITE}Access Token: {Fore.RED}Not found"
f" Auth0 did not issue an access token.")
Expand Down Expand Up @@ -386,7 +386,7 @@ def part_nine(self):
json_response = response.json()
access_token = json_response['accessToken']
print(f"{Fore.GREEN}[OpenAI][9] {Fore.WHITE}Access Token: {Fore.GREEN}{access_token}")
self.save_access_token(access_token=access_token)
self.save_access_token(email=self.email_address, access_token=access_token)
else:
print(f"{Fore.GREEN}[OpenAI][9] {Fore.WHITE}Access Token: {Fore.RED}Not found, "
f"Please try again with a proxy (or use a new proxy if you are using one)")
Expand All @@ -395,7 +395,7 @@ def part_nine(self):
f"Please try again with a proxy (or use a new proxy if you are using one)")

@staticmethod
def save_access_token(access_token: str, expiry: int or None = None):
def save_access_token(email: str, access_token: str, expiry: int or None = None):
"""
Save access_token and an hour from now on CHATGPT_ACCESS_TOKEN CHATGPT_ACCESS_TOKEN_EXPIRY environment variables
:param expiry:
Expand All @@ -409,8 +409,13 @@ def save_access_token(access_token: str, expiry: int or None = None):
# Get path using os, it's in ./classes/auth.json
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, "auth.json")
with open(path, "w") as f:
f.write(json.dumps({"access_token": access_token, "expires_at": expiry}))
with open(path, "r") as f:
content = f.read()
account_dict = {} if content == '' else json.loads(content)
f.close()
with open(path, "w") as file:
account_dict[email] = {"access_token": access_token, "expires_at": expiry}
file.write(json.dumps(account_dict))

print(f"{Fore.GREEN}[OpenAI][8] {Fore.WHITE}Saved access token")
except Exception as e:
Expand Down
12 changes: 12 additions & 0 deletions src/pychatgpt/classes/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
import json

# Get path using os, it's in ./classes/auth.json
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, "auth.json")
with open(path, "r") as f:
content = f.read()
account_dict = {} if content == '' else json.loads(content)
print(account_dict)
# account_dict[email] = {"access_token": access_token, "expires_at": expiry}
# f.write(json.dumps(account_dict))
14 changes: 7 additions & 7 deletions src/pychatgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ def _setup(self):
raise Exceptions.PyChatGPTException("When resuming a chat, there was an issue reading id_log, make sure that it is formatted correctly.")

# Check for access_token & access_token_expiry in env
if OpenAI.token_expired():
if OpenAI.token_expired(self.email):
self.log(f"{Fore.RED}>> Access Token missing or expired."
f" {Fore.GREEN}Attempting to create them...")
self._create_access_token()
else:
access_token, expiry = OpenAI.get_access_token()
access_token, expiry = OpenAI.get_access_token(self.email)
self.__auth_access_token = access_token
self.__auth_access_token_expiry = expiry

Expand All @@ -157,7 +157,7 @@ def _create_access_token(self) -> bool:
openai_auth.create_token()

# If after creating the token, it's still expired, then something went wrong.
is_still_expired = OpenAI.token_expired()
is_still_expired = OpenAI.token_expired(self.email)
if is_still_expired:
self.log(f"{Fore.RED}>> Failed to create access token.")
return False
Expand Down Expand Up @@ -185,7 +185,7 @@ def ask(self, prompt: str,
raise Exceptions.PyChatGPTException("Cannot enter a non-queue object as the response queue for threads.")

# Check if the access token is expired
if OpenAI.token_expired():
if OpenAI.token_expired(self.email):
self.log(f"{Fore.RED}>> Your access token is expired. {Fore.GREEN}Attempting to recreate it...")
did_create = self._create_access_token()
if did_create:
Expand All @@ -195,7 +195,7 @@ def ask(self, prompt: str,
raise Exceptions.PyChatGPTException("Failed to recreate access token.")

# Get access token
access_token = OpenAI.get_access_token()
access_token = OpenAI.get_access_token(self.email)

# Set conversation IDs if supplied
if previous_convo_id is not None:
Expand Down Expand Up @@ -253,7 +253,7 @@ def cli_chat(self, rep_queue: Queue or None = None):
raise Exceptions.PyChatGPTException("Cannot enter a non-queue object as the response queue for threads.")

# Check if the access token is expired
if OpenAI.token_expired():
if OpenAI.token_expired(self.email):
self.log(f"{Fore.RED}>> Your access token is expired. {Fore.GREEN}Attempting to recreate it...")
did_create = self._create_access_token()
if did_create:
Expand All @@ -268,7 +268,7 @@ def cli_chat(self, rep_queue: Queue or None = None):


# Get access token
access_token = OpenAI.get_access_token()
access_token = OpenAI.get_access_token(self.email)

while True:
try:
Expand Down