为了账号安全,请及时绑定邮箱和手机立即绑定

利用RNN和LSTM生成小说题记

一、选取素材

  • 语料格式:每行一条样本,格式为“标题:正文”,程序取第一个冒号之后的正文作为训练文本,例如:

  • 题记:此情可待成追忆,只是当时已惘然。

二、开发环境:Python 3 + TensorFlow 1.x(代码使用 tf.contrib、tf.placeholder 与 tf.Session,无法直接在 TensorFlow 2 中运行)

三、实战代码

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import collections
import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow.contrib.legacy_seq2seq as seq2seq
import tensorflow.contrib.rnn as rnn

BEGIN_CHAR = '^'END_CHAR = '$'UNKNOWN_CHAR = '*'MAX_LENGTH = 100MIN_LENGTH = 10max_words = 3000epochs = 50# 语料poetry_file = 'story.txt'# 模型文件存放位置save_dir = 'model'class Data:
    def __init__(self):
        self.batch_size = 64
        self.poetry_file = poetry_file
        self.load()
        self.create_batches()    def load(self):
        def handle(line):
            if len(line) > MAX_LENGTH:
                index_end = line.rfind('。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH
                line = line[:index_end + 1]            return BEGIN_CHAR + line + END_CHAR

        self.poetrys = [line.strip().replace(' ', '').split(':')[1] for line in
                        open(self.poetry_file, encoding='utf-8')]
        self.poetrys = [handle(line) for line in self.poetrys if len(line) > MIN_LENGTH]        # 所有字
        words = []        for poetry in self.poetrys:
            words += [word for word in poetry]
        counter = collections.Counter(words)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        words, _ = zip(*count_pairs)        # 取出现频率最高的词的数量组成字典,不在字典中的字用'*'代替
        words_size = min(max_words, len(words))
        self.words = words[:words_size] + (UNKNOWN_CHAR,)
        self.words_size = len(self.words)        # 字映射成id
        self.char2id_dict = {w: i for i, w in enumerate(self.words)}
        self.id2char_dict = {i: w for i, w in enumerate(self.words)}
        self.unknow_char = self.char2id_dict.get(UNKNOWN_CHAR)
        self.char2id = lambda char: self.char2id_dict.get(char, self.unknow_char)
        self.id2char = lambda num: self.id2char_dict.get(num)
        self.poetrys = sorted(self.poetrys, key=lambda line: len(line))
        self.poetrys_vector = [list(map(self.char2id, poetry)) for poetry in self.poetrys]    def create_batches(self):
        self.n_size = len(self.poetrys_vector) // self.batch_size
        self.poetrys_vector = self.poetrys_vector[:self.n_size * self.batch_size]
        self.x_batches = []
        self.y_batches = []        for i in range(self.n_size):
            batches = self.poetrys_vector[i * self.batch_size: (i + 1) * self.batch_size]
            length = max(map(len, batches))            for row in range(self.batch_size):                if len(batches[row]) < length:
                    r = length - len(batches[row])
                    batches[row][len(batches[row]): length] = [self.unknow_char] * r
            xdata = np.array(batches)
            ydata = np.copy(xdata)
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)class Model:
    def __init__(self, data, model='lstm', infer=False):
        self.rnn_size = 128
        self.n_layers = 2

        if infer:
            self.batch_size = 1
        else:
            self.batch_size = data.batch_size        if model == 'rnn':
            cell_rnn = rnn.BasicRNNCell        elif model == 'gru':
            cell_rnn = rnn.GRUCell        elif model == 'lstm':
            cell_rnn = rnn.BasicLSTMCell

        cell = cell_rnn(self.rnn_size, state_is_tuple=False)
        self.cell = rnn.MultiRNNCell([cell] * self.n_layers, state_is_tuple=False)

        self.x_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        self.y_tf = tf.placeholder(tf.int32, [self.batch_size, None])

        self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [self.rnn_size, data.words_size])
            softmax_b = tf.get_variable("softmax_b", [data.words_size])            with tf.device("/cpu:0"):
                embedding = tf.get_variable(                    "embedding", [data.words_size, self.rnn_size])
                inputs = tf.nn.embedding_lookup(embedding, self.x_tf)

        outputs, final_state = tf.nn.dynamic_rnn(
            self.cell, inputs, initial_state=self.initial_state, scope='rnnlm')

        self.output = tf.reshape(outputs, [-1, self.rnn_size])
        self.logits = tf.matmul(self.output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        self.final_state = final_state
        pred = tf.reshape(self.y_tf, [-1])        # seq2seq
        loss = seq2seq.sequence_loss_by_example([self.logits],
                                                [pred],
                                                [tf.ones_like(pred, dtype=tf.float32)], )

        self.cost = tf.reduce_mean(loss)
        self.learning_rate = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 5)

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))def train(data, model):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        n = 0
        for epoch in range(epochs):
            sess.run(tf.assign(model.learning_rate, 0.002 * (0.97 ** epoch)))
            pointer = 0
            for batche in range(data.n_size):
                n += 1
                feed_dict = {model.x_tf: data.x_batches[pointer], model.y_tf: data.y_batches[pointer]}
                pointer += 1
                train_loss, _, _ = sess.run([model.cost, model.final_state, model.train_op], feed_dict=feed_dict)
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}) | train_loss {:.3f}" \
                    .format(epoch * data.n_size + batche,
                            epochs * data.n_size, epoch, train_loss)
                sys.stdout.write(info)
                sys.stdout.flush()                # save
                if (epoch * data.n_size + batche) % 1000 == 0 \                        or (epoch == epochs - 1 and batche == data.n_size - 1):
                    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=n)
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
            sys.stdout.write('\n')def sample(data, model, head=u''):
    def to_word(weights):
        t = np.cumsum(weights)
        s = np.sum(weights)
        sa = int(np.searchsorted(t, np.random.rand(1) * s))        return data.id2char(sa)    for word in head:        if word not in data.words:            return u'{} 不在字典中'.format(word)    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables())
        model_file = tf.train.latest_checkpoint(save_dir)
        saver.restore(sess, model_file)        if head:
            print('生成题记 ---> ', head)
            poem = BEGIN_CHAR            for head_word in head:
                poem += head_word
                x = np.array([list(map(data.char2id, poem))])
                state = sess.run(model.cell.zero_state(1, tf.float32))
                feed_dict = {model.x_tf: x, model.initial_state: state}
                [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
                word = to_word(probs[-1])                while word != u',' and word != u'。':
                    poem += word
                    x = np.zeros((1, 1))
                    x[0, 0] = data.char2id(word)
                    [probs, state] = sess.run([model.probs, model.final_state],
                                              {model.x_tf: x, model.initial_state: state})
                    word = to_word(probs[-1])
                poem += word            return poem[1:]        else:
            poem = ''
            head = BEGIN_CHAR
            x = np.array([list(map(data.char2id, head))])
            state = sess.run(model.cell.zero_state(1, tf.float32))
            feed_dict = {model.x_tf: x, model.initial_state: state}
            [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
            word = to_word(probs[-1])            while word != END_CHAR:
                poem += word
                x = np.zeros((1, 1))
                x[0, 0] = data.char2id(word)
                [probs, state] = sess.run([model.probs, model.final_state],
                                          {model.x_tf: x, model.initial_state: state})
                word = to_word(probs[-1])            return poemif __name__ == '__main__':    # 训练模型
    data = Data()
    model = Model(data=data, infer=False)
    print(train(data, model))    # 生成题记
    # data = Data()
    # model = Model(data=data, infer=True)
    # print(sample(data, model, head='我为秋香'))
输出
生成题记 --->  我为秋香
我罢性不行,为德劝仙兴。秋风暝冰始,香巢深器酒。

输出



作者:_两只橙_
链接:https://www.jianshu.com/p/2b3b253adb00


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消