Goal

To help everyone get more out of chains and make them more flexible, several helper functions can take part in each stage of a chain. This post introduces them:

  • RunnableParallel
  • RunnablePassthrough
  • RunnableLambda
  • RunnableBranch

RunnableParallel

RunnableParallel is useful for reshaping the output of one Runnable to match the input format of the next Runnable in a sequence. Here, the prompt expects its input to be a map with "context" and "question" keys. The user input is just the question, so we need to fetch the context with our retriever and pass the user input through under the "question" key.

# Prompt
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | output_parser
)
res = chain.invoke("how can langsmith help with testing?")
print(res)

# Question: would the following work as well?
res = chain.invoke({"question": "how can langsmith help with testing?"})

No. RunnablePassthrough() forwards its input unchanged, so this is equivalent to "question": {"question": "how can langsmith help with testing?"}.
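A minimal pure-Python sketch of why the nesting happens, treating RunnablePassthrough as a plain identity function (which is its documented behavior):

```python
def passthrough(x):
    # RunnablePassthrough() simply returns its input unchanged
    return x

# chain.invoke("how can ...") -> the string itself becomes the "question" value
single = {"context": "<docs>", "question": passthrough("how can langsmith help with testing?")}

# chain.invoke({"question": "..."}) -> the whole dict becomes the "question" value
nested = {"context": "<docs>", "question": passthrough({"question": "how can langsmith help with testing?"})}

print(single["question"])
print(nested["question"])
```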

The mapping here can be written in any of these equivalent forms:

# As a plain dict
{"context": retriever, "question": RunnablePassthrough()}

# As RunnableParallel wrapping a dict
RunnableParallel({"context": retriever, "question": RunnablePassthrough()})

# As RunnableParallel with keyword arguments
RunnableParallel(context=retriever, question=RunnablePassthrough())

Key point: you can also use itemgetter to pull a value out of the input by key:

chain = (
    # Originally: {"context": retriever, "question": RunnablePassthrough()}
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
    }
    | prompt
    | model
    | StrOutputParser()
)
res = chain.invoke({"question": "how can langsmith help with testing?"})
print(res)
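itemgetter is just the standard-library helper from the operator module: it builds a callable that extracts a key from a mapping, which is why it can slot into a chain:

```python
from operator import itemgetter

# itemgetter("question") is a plain callable that extracts that key from a mapping
get_question = itemgetter("question")
result = get_question({"question": "how can langsmith help with testing?"})
print(result)
# -> how can langsmith help with testing?
```

Note that with itemgetter in the map, the chain's input must now be a dict, which is why invoke takes {"question": ...} rather than a bare string.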

Running chains in parallel

joke_chain = ChatPromptTemplate.from_template("Tell me a joke about {topic}") | model
topic_chain = (
    ChatPromptTemplate.from_template("Tell me a topic about {topic}") | model
)

map_chain = RunnableParallel(joke=joke_chain, topic=topic_chain)

res = map_chain.invoke({"topic": "ice cream"})
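Conceptually, RunnableParallel hands the same input to every branch and collects the results into a dict keyed by branch name. A rough pure-Python analogy (an illustration, not LangChain's actual implementation):

```python
from concurrent.futures import ThreadPoolExecutor

def run_parallel(branches, value):
    # Fan the same input out to each branch and gather results by key
    with ThreadPoolExecutor() as pool:
        futures = {name: pool.submit(fn, value) for name, fn in branches.items()}
        return {name: f.result() for name, f in futures.items()}

result = run_parallel(
    {"joke": lambda x: f"a joke about {x}", "topic": lambda x: f"a topic about {x}"},
    "ice cream",
)
print(result)
# -> {'joke': 'a joke about ice cream', 'topic': 'a topic about ice cream'}
```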

RunnablePassthrough

RunnablePassthrough lets you pass the input through unchanged, or add extra keys to it. It is typically used together with RunnableParallel to assign data to a new key in the map. When RunnablePassthrough() is called on its own, it simply takes the input and passes it straight through. When called via its assign method (RunnablePassthrough.assign(…)), it takes the input and adds the extra keys produced by the arguments passed to assign.

# RunnablePassthrough
runnable = RunnableParallel(
    passed=RunnablePassthrough(),
    extra=RunnablePassthrough.assign(mult=lambda x: x["num"] * 3),
    modified=lambda x: x["num"] + 1,
)

runnable.invoke({"num": 1})

------
# Result
{'extra': {'mult': 3, 'num': 1}, 'modified': 2, 'passed': {'num': 1}}
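What assign does can be sketched with plain dicts: it keeps all of the original keys and layers the computed ones on top (an illustrative simplification, not the real implementation):

```python
def assign(value, **computed):
    # RunnablePassthrough.assign(...) keeps the original keys and adds new ones
    return {**value, **{key: fn(value) for key, fn in computed.items()}}

result = assign({"num": 1}, mult=lambda x: x["num"] * 3)
print(result)
# -> {'num': 1, 'mult': 3}
```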

RunnableLambda

You can use arbitrary functions in the pipeline. Note that every such function must accept a single argument. If you have a function that takes multiple arguments, write a wrapper that accepts a single input and unpacks it into those arguments.
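For example, a two-argument function (the add function here is hypothetical) can be adapted with a small unpacking wrapper:

```python
def add(a, b):
    return a + b

def add_wrapper(inputs):
    # Unpack the single dict input into the two arguments `add` expects
    return add(inputs["a"], inputs["b"])

# RunnableLambda(add_wrapper) could then be used as a chain step
result = add_wrapper({"a": 1, "b": 2})
print(result)
# -> 3
```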

def format_docs_into_text(docs):
    doc_size = len(docs)
    print('received {} documents'.format(doc_size))
    return docs

chain = (
    {
        "context": itemgetter("question") | retriever | RunnableLambda(format_docs_into_text),
        "question": itemgetter("question"),
    }
    | prompt
    | model
    | StrOutputParser()
)
res = chain.invoke({"question": "how can langsmith help with testing?"})
print(res)

Both of the following functions can be called in the chain without errors, so where is the difference?

A function that returns the Document objects directly:

def format_docs_into_text(docs):
    doc_size = len(docs)
    print('received {} documents'.format(doc_size))
    return docs

A function that joins the documents into a single string:

def format_docs_into_text2(docs):
    doc_size = len(docs)
    print('received {} documents'.format(doc_size))
    return "\n\n".join(doc.page_content for doc in docs)

The generated prompts differ.

When returning docs directly:

[HumanMessage(content="Answer the question based only on the following context:\n [Document(page_content='LangSmith User Guide | 🦜️🛠️ …

When returning a string:

[HumanMessage(content='Answer the question based only on the following context:\n LangSmith User Guide | 🦜️🛠️ LangSmith\n\nLangSmith User Guide | 🦜️🛠️
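The difference is easy to reproduce with a stand-in Document class (a simplification of LangChain's actual Document type): the prompt template stringifies whatever it receives, so a list of objects becomes a noisy repr while a joined string stays clean.

```python
from dataclasses import dataclass

@dataclass
class Document:
    page_content: str

docs = [Document("LangSmith User Guide"), Document("Testing with LangSmith")]

# Returning docs as-is: the prompt ends up containing the list's repr
as_repr = str(docs)

# Joining page_content: the prompt receives clean text
as_text = "\n\n".join(doc.page_content for doc in docs)

print(as_repr)
print(as_text)
```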


RunnableBranch

One routing chain plus three branch chains.

# Three branches: langchain / baidu / default

langchain_chain = (
    PromptTemplate.from_template(
        """You are a LangChain expert. Answer the question below:
Question: {question}
Answer:"""
    )
    | model
)
baidu_chain = (
    PromptTemplate.from_template(
        """You are a Baidu AI expert. Answer the question below:
Question: {question}
Answer:"""
    )
    | model
)
general_chain = (
    PromptTemplate.from_template(
        """
Answer the question below: {question}
"""
    )
    | model
)

# Classification chain: decides which branch to take
chain = (
    PromptTemplate.from_template(
        """Given the user question, classify it as `LangChain`, `Baidu`, or `Other`.
Do not return any extra words.
<question>
{question}
</question>
Classification:"""
    )
    | model
    | StrOutputParser()
)

Choosing the branch

def route(info):
    if "baidu" in info["topic"].lower():
        return baidu_chain
    elif "langchain" in info["topic"].lower():
        return langchain_chain
    else:
        return general_chain
# ---------------
# Result: the input that route receives
{'question': 'How do I use Baidu?', 'topic': 'Baidu'}
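LangChain also ships a RunnableBranch class for this pattern; conceptually it checks (condition, runnable) pairs in order and falls back to a default, roughly like this pure-Python sketch (an analogy, not the actual implementation):

```python
def branch(pairs, default, value):
    # Try each (condition, handler) pair in order; fall back to the default
    for condition, handler in pairs:
        if condition(value):
            return handler(value)
    return default(value)

result = branch(
    [
        (lambda x: "baidu" in x["topic"].lower(), lambda x: "baidu branch"),
        (lambda x: "langchain" in x["topic"].lower(), lambda x: "langchain branch"),
    ],
    lambda x: "general branch",
    {"topic": "Baidu", "question": "How do I use Baidu?"},
)
print(result)
# -> baidu branch
```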

The full chain

full_chain = {"topic": chain, "question": lambda x: x["question"]} | RunnableLambda(
    route
)

res = full_chain.invoke({"question": "How do I use Baidu?"})
print(res)

Full code

import os
import uuid

from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain_community.embeddings import QianfanEmbeddingsEndpoint
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableLambda

from operator import itemgetter


if __name__ == '__main__':

    os.environ["QIANFAN_ACCESS_KEY"] = os.getenv('MY_QIANFAN_ACCESS_KEY')
    os.environ["QIANFAN_SECRET_KEY"] = os.getenv('MY_QIANFAN_SECRET_KEY')

    unique_id = uuid.uuid4().hex[0:8]
    os.environ["LANGCHAIN_PROJECT"] = f" [returns docs] trace - {unique_id}"
    os.environ["LANGCHAIN_TRACING_V2"] = 'true'
    os.environ["LANGCHAIN_API_KEY"] = os.getenv('MY_LANGCHAIN_API_KEY')

    model = QianfanChatEndpoint(
        model="ERNIE-Bot-4"
    )

    # Retrieval chain

    # Qianfan embedding model
    embeddings_model = QianfanEmbeddingsEndpoint(model="bge_large_en", endpoint="bge_large_en")
    # Load the vector store
    vector_store = Chroma(persist_directory="D:\\LLM\\my_projects\\chroma_db", embedding_function=embeddings_model)

    # Create the retriever
    retriever = vector_store.as_retriever()

    # Prompt
    template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    chain = (
        # Originally: {"context": retriever, "question": RunnablePassthrough()}
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | prompt
        | model
        | StrOutputParser()
    )
    # res = chain.invoke({"question": "how can langsmith help with testing?"})
    # print(res)


    # Parallel execution
    joke_chain = ChatPromptTemplate.from_template("Tell me a joke about {topic}") | model
    topic_chain = (
        ChatPromptTemplate.from_template("Tell me a topic about {topic}") | model
    )

    map_chain = RunnableParallel(joke=joke_chain, topic=topic_chain)

    # res = map_chain.invoke({"topic": "ice cream"})

    # RunnablePassthrough
    runnable = RunnableParallel(
        passed=RunnablePassthrough(),
        extra=RunnablePassthrough.assign(mult=lambda x: x["num"] * 3),
        modified=lambda x: x["num"] + 1,
    )

    # res = runnable.invoke({"num": 1})

    # RunnableLambda
    def format_docs_into_text(docs):
        doc_size = len(docs)
        print('received {} documents'.format(doc_size))
        return docs

    def format_docs_into_text2(docs):
        doc_size = len(docs)
        print('received {} documents'.format(doc_size))
        return "\n\n".join(doc.page_content for doc in docs)

    chain = (
        {
            "context": itemgetter("question") | retriever | RunnableLambda(format_docs_into_text),
            "question": itemgetter("question"),
        }
        | prompt
        | model
        | StrOutputParser()
    )

    chain_test = (
        {
            "context": itemgetter("question") | retriever | RunnableLambda(format_docs_into_text),
            "question": itemgetter("question"),
        }
        | prompt
    )
    res = chain_test.invoke({"question": "how can langsmith help with testing?"})

    chain_test_2 = (
        {
            "context": itemgetter("question") | retriever | RunnableLambda(format_docs_into_text2),
            "question": itemgetter("question"),
        }
        | prompt
    )
    res = chain_test_2.invoke({"question": "how can langsmith help with testing?"})


    res = chain.invoke({"question": "how can langsmith help with testing?"})
    print(res)

    # RunnableBranch

    # Three branches: langchain / baidu / default

    langchain_chain = (
        PromptTemplate.from_template(
            """You are a LangChain expert. Answer the question below:
Question: {question}
Answer:"""
        )
        | model
    )
    baidu_chain = (
        PromptTemplate.from_template(
            """You are a Baidu AI expert. Answer the question below:
Question: {question}
Answer:"""
        )
        | model
    )
    general_chain = (
        PromptTemplate.from_template(
            """
Answer the question below: {question}
"""
        )
        | model
    )

    # Classification chain: decides which branch to take
    chain = (
        PromptTemplate.from_template(
            """Given the user question, classify it as `LangChain`, `Baidu`, or `Other`.
Do not return any extra words.
<question>
{question}
</question>
Classification:"""
        )
        | model
        | StrOutputParser()
    )

    # Routing
    def route(info):
        if "baidu" in info["topic"].lower():
            return baidu_chain
        elif "langchain" in info["topic"].lower():
            return langchain_chain
        else:
            return general_chain


    full_chain = {"topic": chain, "question": lambda x: x["question"]} | RunnableLambda(
        route
    )

    res = full_chain.invoke({"question": "How do I use Baidu?"})
    print(res)