-
Notifications
You must be signed in to change notification settings - Fork 315
/
Copy pathtree_of_thought.py
60 lines (50 loc) · 2.08 KB
/
tree_of_thought.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from config import set_environment
from langchain.chains import LLMChain, SequentialChain
from langchain_core.prompts import PromptTemplate
from langchain_openai.chat_models import ChatOpenAI
# Populate environment variables (API keys) before any LLM is instantiated.
set_environment()

# Stage 1 prompt: brainstorm several candidate solutions to the problem.
solutions_template = """
Generate {num_solutions} distinct solutions for {problem}. Consider factors like {factors}.
Solutions:
"""
# from_template infers {problem}, {factors}, {num_solutions} automatically.
solutions_prompt = PromptTemplate.from_template(solutions_template)

# Stage 2 prompt: weigh each candidate's pros, cons, and feasibility.
evaluation_template = """
Evaluate each solution in {solutions} by analyzing pros, cons, feasibility,
and probability of success.
Evaluations:
"""
evaluation_prompt = PromptTemplate.from_template(evaluation_template)

# Stage 3 prompt: deepen the reasoning around the strongest candidates.
reasoning_template = """
For the most promising solutions in {evaluations}, explain scenarios, implementation strategies,
partnerships needed, and handling potential obstacles.
Enhanced Reasoning:
"""
reasoning_prompt = PromptTemplate.from_template(reasoning_template)

# Stage 4 prompt: order the candidates from most to least promising.
ranking_template = """
Based on the evaluations and reasoning, rank the solutions in {enhanced_reasoning} from
most to least promising.
Ranked Solutions:
"""
ranking_prompt = PromptTemplate.from_template(ranking_template)
# One LLMChain per Tree-of-Thoughts stage; each stage's output_key matches the
# input variable of the next stage's prompt, so SequentialChain can wire them.
solutions_chain = LLMChain(llm=ChatOpenAI(), prompt=solutions_prompt, output_key="solutions")
evaluation_chain = LLMChain(llm=ChatOpenAI(), prompt=evaluation_prompt, output_key="evaluations")
# Backward-compatible alias for the previously misspelled name.
evalutation_chain = evaluation_chain
reasoning_chain = LLMChain(
    llm=ChatOpenAI(), prompt=reasoning_prompt, output_key="enhanced_reasoning"
)
ranking_chain = LLMChain(llm=ChatOpenAI(), prompt=ranking_prompt, output_key="ranked_solutions")

# Compose the pipeline: generate -> evaluate -> reason -> rank.
# Only the final ranking is surfaced to the caller.
tot_chain = SequentialChain(
    chains=[solutions_chain, evaluation_chain, reasoning_chain, ranking_chain],
    input_variables=["problem", "factors", "num_solutions"],
    output_variables=["ranked_solutions"],
)


def main() -> None:
    """Run the Tree-of-Thoughts pipeline on a demo problem and print the ranking."""
    print(
        tot_chain.run(
            problem="Prompt engineering",
            factors="Requirements for high task performance, low token use, and few calls to the LLM",
            num_solutions=3,
        )
    )


if __name__ == "__main__":
    # Fix: the demo previously ran unconditionally at import time while this
    # guard was a dead `pass`; the LLM calls now happen only when the module
    # is executed as a script, not when it is imported.
    main()