forked from cloveranon/Clover-Edition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
storymanager.py
120 lines (104 loc) · 4.66 KB
/
storymanager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
import re
from getconfig import settings
from utils import output, format_result, format_input, get_similarity
class Story:
    """An interactive story: a fixed opening context, memory, and turn history.

    The initial prompt (``context``) is special: it is sent to the generator on
    every call so that it stays permanently inside the AI's limited context
    window, together with any extra strings kept in ``memory``.
    """

    def __init__(self, generator, context='', memory=None):
        """Create a story.

        Args:
            generator: object exposing ``generate`` (used by ``act``) and
                ``generate_raw`` (used by ``get_suggestion``).
            context: the initial prompt, always included in the generator context.
            memory: optional list of extra strings appended to the context.
                A fresh list is created when omitted, so instances never share one.
        """
        if memory is None:
            memory = []
        self.generator = generator
        self.context = context
        self.memory = memory
        self.actions = []  # player inputs, one per turn
        self.results = []  # generator outputs, index-paired with self.actions
        self.savefile = ""

    def _full_context(self):
        """Return the context passed to the generator: ``context`` plus ``memory``.

        Non-empty pieces are joined with single spaces so the last word of the
        context never runs straight into the first memory entry.
        """
        return ' '.join(part for part in [self.context, *self.memory] if part)

    def act(self, action, record=True, format=True):
        """Send ``action`` to the generator and return its continuation.

        Args:
            action: the player's input text, appended to the story so far.
            record: when true, append the formatted action and result to history.
            format: when true, return ``format_result(result)``; otherwise the
                raw generator output.

        Raises:
            AssertionError: if both context and action are blank, or the
                ``top-keks`` setting is missing (checks are skipped under ``-O``).
        """
        assert (self.context.strip() + action.strip())
        assert (settings.getint('top-keks') is not None)
        result = self.generator.generate(
            self.get_story() + action,
            self._full_context(),
            temperature=settings.getfloat('temp'),
            top_p=settings.getfloat('top-p'),
            top_k=settings.getint('top-keks'),
            repetition_penalty=settings.getfloat('rep-pen'))
        if record:
            self.actions.append(format_input(action))
            # NOTE(review): the result is normalised with format_input, same as
            # the action; confirm this is intended rather than format_result.
            self.results.append(format_input(result))
        return format_result(result) if format else result

    @staticmethod
    def _separator(start, result):
        """Return the text placed between the opening text and the first result.

        A single space is used only when ``start`` does not end in ``.``, ``!``
        or ``?`` (ignoring trailing whitespace) and ``result`` looks like a
        continuation (its first non-space character is lowercase, punctuation,
        or a double quote). Otherwise a newline is used.
        """
        # re.search, not re.match: the punctuation must be found at the END of
        # start; re.match would only succeed if start began with it.
        start_ends_sentence = re.search(r"[.!?]\s*$", start)
        result_continues = re.match(r"^\s*[a-z.!?,\"]", result)
        return ' ' if not start_ends_sentence and result_continues else '\n'

    def print_action_result(self, i, wrap=True, color=True):
        """Print turn ``i`` (action, then result) via ``output``.

        For turn 0, or whenever there is exactly one action, the context is
        printed first and the first result is joined to it with ``_separator``.
        Other turns print the action (prefixed with ``> `` when it starts with
        "you") and the result, skipping blank entries and out-of-range indices.

        Args:
            i: turn index; negative indices count from the end.
            wrap: forwarded to ``output``.
            color: when false, print without the user/AI colour names.
        """
        col1 = 'user-text' if color else None
        col2 = 'ai-text' if color else None
        if i == 0 or len(self.actions) == 1:
            start = format_result(self.context + ' ' + self.actions[0])
            result = format_result(self.results[0])
            sep = self._separator(start, result)
            if not self.actions[0]:
                output(self.context, col1, self.results[0], col2, sep=sep, wrap=wrap)
            else:
                output(self.context, col1, wrap=wrap)
                output(self.actions[0], col1, self.results[0], col2, sep=sep, wrap=wrap)
        else:
            if i < len(self.actions) and self.actions[i].strip() != "":
                caret = "> " if re.match(r"^ *you +", self.actions[i], flags=re.I) else ""
                output(format_result(caret + self.actions[i]), col1, wrap=wrap)
            if i < len(self.results) and self.results[i].strip() != "":
                output(format_result(self.results[i]), col2, wrap=wrap)

    def print_story(self, wrap=True, color=True):
        """Print every turn in order, covering the longer of actions/results."""
        for i in range(0, max(len(self.actions), len(self.results))):
            self.print_action_result(i, wrap=wrap, color=color)

    def print_last(self, wrap=True, color=True):
        """Print the most recent turn; does nothing when no actions exist yet."""
        if not self.actions:
            return
        self.print_action_result(-1, wrap=wrap, color=color)

    def get_story(self):
        """Return actions and results interleaved, separated by blank lines.

        Only complete (action, result) pairs are included, because ``zip``
        stops at the shorter list.
        """
        lines = [val for pair in zip(self.actions, self.results) for val in pair]
        return '\n\n'.join(lines)

    def revert(self):
        """Drop the last action and the last result, if any."""
        self.actions = self.actions[:-1]
        self.results = self.results[:-1]

    def get_suggestion(self):
        """Ask the generator for a suggested next action.

        The prompt is the story so far followed by ``"\\n\\n> You"``, with
        ``context`` (memory is not included) as the context. Only the first
        line of the raw output is returned.
        """
        return re.sub('\n.*', '',
                      self.generator.generate_raw(
                          self.get_story() + "\n\n> You",
                          self.context,
                          temperature=settings.getfloat('action-temp'),
                          top_p=settings.getfloat('top-p'),
                          top_k=settings.getint('top-keks'),
                          repetition_penalty=1))

    def __str__(self):
        return self.context + ' ' + self.get_story()

    def to_dict(self):
        """Snapshot the sampling settings and story state into a dict (see ``to_json``)."""
        res = {}
        res["temp"] = settings.getfloat('temp')
        res["top-p"] = settings.getfloat("top-p")
        res["top-keks"] = settings.getint("top-keks")
        res["rep-pen"] = settings.getfloat("rep-pen")
        res["context"] = self.context
        res["memory"] = self.memory
        res["actions"] = self.actions
        res["results"] = self.results
        return res

    def from_dict(self, d):
        """Restore state from a dict shaped like ``to_dict`` output.

        The saved sampling parameters are also written back, as strings, into
        the shared ``settings`` object.
        """
        settings["temp"] = str(d["temp"])
        settings["top-p"] = str(d["top-p"])
        settings["top-keks"] = str(d["top-keks"])
        settings["rep-pen"] = str(d["rep-pen"])
        self.context = d["context"]
        self.memory = d["memory"]
        self.actions = d["actions"]
        self.results = d["results"]

    def to_json(self):
        """Serialise ``to_dict()`` to a JSON string."""
        return json.dumps(self.to_dict())

    def from_json(self, j):
        """Restore state from a JSON string produced by ``to_json``."""
        self.from_dict(json.loads(j))

    def is_looping(self, threshold=0.9):
        """Return True when the last two results are more similar than ``threshold``.

        Similarity is computed by ``get_similarity``; its scale is assumed to
        match ``threshold`` (presumably 0..1 — confirm in utils).
        """
        if len(self.results) >= 2:
            similarity = get_similarity(self.results[-1], self.results[-2])
            if similarity > threshold:
                return True
        return False

    # TODO: implement saving.
    # def save()
    # file=Path('saves', self.filename)
    # NOTE(review): the stub refers to self.filename, but __init__ defines self.savefile.