-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathcli_request_scribe.py
127 lines (117 loc) · 6.29 KB
/
cli_request_scribe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import requests, json, os, time, argparse, base64
import yaml
import sys
from cli_logger import logger, set_logger_verbosity, quiesce_logger, test_logger
from PIL import Image
from io import BytesIO
from requests.exceptions import ConnectionError
# Command-line interface. Every option is optional; any value that is given
# overrides both the defaults in RequestData and the cliRequestsData_Scribe.yml
# file (see load_request_data).
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--api_key', type=str, action='store', required=False, help="The API Key to use to authenticate on the Horde. Get one in https://aihorde.net/register")
# This client generates text, not images; the help strings below say so.
arg_parser.add_argument('-n', '--amount', action="store", required=False, type=int, help="The amount of texts to generate with this prompt")
arg_parser.add_argument('-p', '--prompt', action="store", required=False, type=str, help="The prompt with which to generate text")
arg_parser.add_argument('-c', '--max_context_length', action="store", required=False, type=int, help="The maximum amount of tokens to read from the prompt")
arg_parser.add_argument('-l', '--max_length', action="store", required=False, type=int, help="The maximum amount of tokens to generate")
arg_parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
arg_parser.add_argument('-q', '--quiet', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
arg_parser.add_argument('--horde', action="store", required=False, type=str, default="https://aihorde.net", help="Use a different horde")
arg_parser.add_argument('--trusted_workers', action="store_true", default=False, required=False, help="If true, the request will be sent only to trusted workers.")
arg_parser.add_argument('--dry_run', action="store_true", default=False, required=False, help="If true, The request will only print the amount of kudos the payload would spend, and exit.")
# Parsed once at import time; the other functions read the module-level `args`.
args = arg_parser.parse_args()
class RequestData(object):
    """Mutable description of one text-generation request.

    Attributes start from the defaults below and may be replaced
    attribute-by-attribute (see load_request_data) before the JSON payload is
    built with get_submit_dict().
    """

    def __init__(self):
        # Sent as the Client-Agent header on the submit request.
        self.client_agent = "cli_request_scribe.py:1.1.0:(discord)db0#1625"
        # Sent as the apikey header. Placeholder default (presumably the Horde's
        # anonymous key) -- override via the YAML config or --api_key.
        self.api_key = "0000000000"
        # Generation parameters; nested under "params" in the submit payload.
        self.txtgen_params = {
            "n": 1,
            "max_context_length": 1024,
            "max_length": 40,
        }
        # Top-level fields of the submit payload.
        self.submit_dict = {
            "prompt": "a horde of cute kobolds furiously typing on typewriters",
            "trusted_workers": False,
            "models": [],
            "dry_run": False
        }

    def get_submit_dict(self):
        """Return the JSON payload for the text-generation submit endpoint.

        The payload is ``submit_dict`` with an added ``"params"`` key holding
        ``txtgen_params``.

        Returns:
            dict: A deep copy. Callers may modify it, including nested values
            such as ``params`` or ``models``, without changing this instance.
        """
        import copy  # local import: only this method needs it

        # A shallow copy would share the nested params dict and models list
        # with this instance, so edits to the payload would leak back into it.
        submit_dict = copy.deepcopy(self.submit_dict)
        submit_dict["params"] = copy.deepcopy(self.txtgen_params)
        return submit_dict
def load_request_data():
    """Build the RequestData for this run from defaults, config file and CLI.

    Precedence, lowest to highest:
      1. Defaults set in ``RequestData.__init__``.
      2. Top-level keys of ``cliRequestsData_Scribe.yml`` in the current
         working directory, if that file exists. Each key replaces the
         same-named attribute wholesale. For example, a ``txtgen_params``
         mapping replaces the whole default dict rather than being merged
         into it.
      3. Command-line arguments given a truthy value (so, e.g., ``-n 0`` is
         ignored).

    Returns:
        RequestData: The populated request description.
    """
    request_data = RequestData()
    config_path = "cliRequestsData_Scribe.yml"
    if os.path.exists(config_path):
        with open(config_path, "rt", encoding="utf-8", errors="ignore") as configfile:
            # safe_load returns None for an empty (or comment-only) file;
            # treat that as "no overrides" instead of crashing on None.items().
            config = yaml.safe_load(configfile) or {}
        for key, value in config.items():
            setattr(request_data, key, value)
    # CLI overrides are applied last so they win over the YAML file.
    if args.api_key:
        request_data.api_key = args.api_key
    if args.amount:
        request_data.txtgen_params["n"] = args.amount
    if args.max_context_length:
        request_data.txtgen_params["max_context_length"] = args.max_context_length
    if args.max_length:
        request_data.txtgen_params["max_length"] = args.max_length
    if args.prompt:
        request_data.submit_dict["prompt"] = args.prompt
    if args.trusted_workers:
        request_data.submit_dict["trusted_workers"] = args.trusted_workers
    if args.dry_run:
        request_data.submit_dict["dry_run"] = args.dry_run
    return request_data
@logger.catch(reraise=True)
def generate():
    """Submit one async text-generation request to the Horde and log the results.

    Flow:
      1. POST the payload from load_request_data() to
         ``{args.horde}/api/v2/generate/text/async``.
      2. Poll ``{args.horde}/api/v2/generate/text/status/{id}`` until its
         ``done`` flag is set. Ctrl+C while polling sends a DELETE to the same
         URL, and that DELETE response is used as the final status.
      3. Log each returned generation with ``logger.generation``, or log an
         error if the request is reported as faulted.

    Non-2xx responses from the Horde are logged with ``logger.error`` and the
    function returns early. A ``requests`` ConnectionError while polling is
    retried up to 10 times in a row and then re-raised.
    """
    request_data = load_request_data()
    headers = {
        "apikey": request_data.api_key,
        "Client-Agent": request_data.client_agent,
    }
    submit_req = requests.post(f'{args.horde}/api/v2/generate/text/async', json=request_data.get_submit_dict(), headers=headers)
    if not submit_req.ok:
        logger.error(submit_req.text)
        return
    submit_results = submit_req.json()
    logger.debug(submit_results)
    req_id = submit_results.get('id')
    if not req_id:
        # Without an id there is nothing to poll; show the Horde's reply as-is
        # (presumably this is how --dry_run kudos estimates come back).
        logger.message(submit_results)
        return

    status_url = f'{args.horde}/api/v2/generate/text/status/{req_id}'
    is_done = False
    retry = 0
    try:
        while not is_done:
            try:
                chk_req = requests.get(status_url)
                if not chk_req.ok:
                    logger.error(chk_req.text)
                    return
                chk_results = chk_req.json()
                logger.info(chk_results)
                is_done = chk_results['done']
                # Only consecutive connection failures count towards the limit.
                retry = 0
                if not is_done:
                    time.sleep(0.8)
            except ConnectionError as e:
                retry += 1
                logger.error(f"Error {e} when retrieving status. Retry {retry}/10")
                if retry < 10:
                    time.sleep(1)
                    continue
                raise
    except KeyboardInterrupt:
        logger.info(f"Cancelling {req_id}...")
        retrieve_req = requests.delete(status_url)
    else:
        # Polling finished normally: fetch the final status once more.
        retrieve_req = requests.get(status_url)

    if not retrieve_req.ok:
        logger.error(retrieve_req.text)
        return
    results_json = retrieve_req.json()
    if results_json['faulted']:
        final_submit_dict = request_data.get_submit_dict()
        logger.error(f"Something went wrong when generating the request. Please contact the horde administrator with your request details: {final_submit_dict}")
        return
    for index, generation in enumerate(results_json['generations']):
        text = generation['text']
        if not text:
            logger.generation(f"{index}: <This generation returned an empty string (EOS)>")
        else:
            logger.generation(f"{index}: {text}")
if __name__ == "__main__":
    # Configure logging from -v / -q before doing any work, then run the
    # request. Guarded so that importing this module does not hit the network.
    set_logger_verbosity(args.verbosity)
    quiesce_logger(args.quiet)
    generate()