1 Answer

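Assume that graph_nodes is a dictionary, that every sentence starts with the symbol <s> with probability 1.0, and that every sentence ends with the special symbol </s>. To avoid sorting the hypotheses, I keep them in a heap, so adding an element takes time logarithmic in the beam size.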
import heapq

# sen_length, beam_width, graph_nodes and get_prob are assumed to be defined by the caller.
# Each beam entry is (score, hypothesis). heapq is a min-heap, so heappop removes
# the lowest-scoring entry and the heap keeps the best beam_width hypotheses.
beam = [(1.0, ["<s>"])]
for _ in range(sen_length):
    new_beam = []
    for score, hypothesis in beam:
        hypothesis_end = hypothesis[-1]
        # finished hypothesis will return to the beam and compete with new ones
        if hypothesis_end == "</s>":
            heapq.heappush(new_beam, (score, hypothesis))
            if len(new_beam) > beam_width:
                heapq.heappop(new_beam)
            continue  # a finished hypothesis must not be expanded further
        # expand unfinished hypothesis
        for possible_continuation in graph_nodes[hypothesis_end]:
            continuation_score = score * get_prob(hypothesis_end, possible_continuation)
            heapq.heappush(
                new_beam, (continuation_score, hypothesis + [possible_continuation])
            )
            if len(new_beam) > beam_width:
                heapq.heappop(new_beam)
    beam = new_beam
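To try the loop above on something concrete, here is a minimal self-contained sketch. The toy graph_nodes dictionary, the PROBS table, the get_prob lookup and the beam_search wrapper are all made up for illustration and are not part of the original question. Only the loop structure follows the answer.

```python
import heapq

# Hypothetical toy model: graph_nodes maps a token to its possible successors.
graph_nodes = {
    "<s>": ["a", "b"],
    "a": ["b", "</s>"],
    "b": ["a", "</s>"],
}

# Hypothetical transition probabilities P(current | previous).
PROBS = {
    ("<s>", "a"): 0.6, ("<s>", "b"): 0.4,
    ("a", "b"): 0.3, ("a", "</s>"): 0.7,
    ("b", "a"): 0.4, ("b", "</s>"): 0.6,
}


def get_prob(previous, current):
    """Look up the assumed transition probability."""
    return PROBS[(previous, current)]


def beam_search(sen_length, beam_width):
    """Run at most sen_length expansion steps, keeping beam_width hypotheses."""
    beam = [(1.0, ["<s>"])]
    for _ in range(sen_length):
        new_beam = []
        for score, hypothesis in beam:
            last = hypothesis[-1]
            if last == "</s>":
                # Finished hypotheses are carried over unchanged.
                candidates = [(score, hypothesis)]
            else:
                # Unfinished hypotheses are extended by every successor.
                candidates = [
                    (score * get_prob(last, nxt), hypothesis + [nxt])
                    for nxt in graph_nodes[last]
                ]
            for candidate in candidates:
                heapq.heappush(new_beam, candidate)
                if len(new_beam) > beam_width:
                    heapq.heappop(new_beam)  # drop the current worst entry
        beam = new_beam
    # The heap is not fully ordered, so sort once at the end, best first.
    return sorted(beam, reverse=True)


for score, hypothesis in beam_search(sen_length=3, beam_width=2):
    print(round(score, 4), " ".join(hypothesis))
```

With this toy model and a beam width of 2, the two surviving hypotheses are the ones that end in </s> after a single token. Note that the heap only guarantees that its smallest element is at index 0, which is why the final result is sorted explicitly.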
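If your hypotheses can have different lengths, you should consider some form of length normalization (for example, the geometric mean of the probabilities). Also, multiplying probabilities is not always numerically stable, so you may want to sum log-probabilities instead.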
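The two suggestions can be sketched as follows. The helper names log_score and length_normalized and the example probability lists are hypothetical, chosen only to illustrate the idea. Dividing the summed log-probabilities by the number of steps gives the logarithm of the geometric mean.

```python
import math


def log_score(step_probs):
    """Sum of log-probabilities: the log of the product, without underflow."""
    return sum(math.log(p) for p in step_probs)


def length_normalized(step_probs):
    """Mean log-probability, i.e. the log of the geometric mean."""
    return log_score(step_probs) / len(step_probs)


# Hypothetical per-step probabilities of two hypotheses of different length.
short = [0.5, 0.5]
long = [0.6, 0.6, 0.6, 0.6]

# The raw score prefers the short hypothesis simply because it has fewer factors.
print(log_score(short) > log_score(long))
# After normalization the hypothesis with the better per-step probability wins.
print(length_normalized(long) > length_normalized(short))

# A plain product of many small probabilities underflows to 0.0,
# while the log-sum stays a finite negative number.
many = [0.1] * 400
print(math.prod(many), log_score(many))
```

In the beam search loop this would mean storing the log-score in the heap tuple and adding math.log(get_prob(...)) instead of multiplying. Whether to rank by the raw or the normalized score is a design choice that depends on how much you want to favor longer outputs.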