1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
class LSTMPoetryModel(object):
    """Character-level LSTM model that generates five-character quatrains.

    A window of ``config.max_len`` characters is one-hot encoded and the
    network predicts the next character. Poems in the corpus text are
    separated by ``']'``.
    """

    def __init__(self, config):
        """Load the corpus, then restore a saved model or train a new one.

        Args:
            config: configuration object; this class reads ``weight_file``,
                ``max_len``, ``learning_rate`` and ``batch_size`` from it.
        """
        self.model = None
        self.do_train = True  # NOTE(review): never read inside this class
        self.loaded_model = True  # when True, reuse config.weight_file if it exists
        self.config = config

        # preprocess_data is defined elsewhere. word2idx_dic is *called* with a
        # character further down, so it is presumably a char -> index function.
        self.word2idx_dic, self.idx2word, self.words, self.files_content = preprocess_data(self.config)
        self.poems = self.files_content.split(']')
        self.poems_num = len(self.poems)

        if os.path.exists(self.config.weight_file) and self.loaded_model:
            self.model = load_model(self.config.weight_file)
        else:
            self.train()

    def build_model(self):
        """Build and compile the two-layer LSTM network into ``self.model``."""
        print('模型构建中...')
        vocab_size = len(self.words)

        input_tensor = Input(shape=(self.config.max_len, vocab_size))
        lstm = LSTM(512, return_sequences=True)(input_tensor)
        dropout = Dropout(0.6)(lstm)
        lstm = LSTM(256)(dropout)
        dropout = Dropout(0.6)(lstm)
        dense = Dense(vocab_size, activation='softmax')(dropout)
        self.model = Model(inputs=input_tensor, outputs=dense)

        # NOTE(review): `lr` is the legacy Keras argument name (newer releases
        # expect `learning_rate`); kept because the Keras version is not visible here.
        optimizer = Adam(lr=self.config.learning_rate)
        self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    def sample(self, preds, temperature=1.0):
        """Draw an index from the distribution ``preds`` re-weighted by ``temperature``.

        Each probability is raised to the power ``1 / temperature`` and the
        result is renormalised:

        * ``temperature < 1.0`` sharpens the distribution: likely characters
          become even more likely, so output is more conservative;
        * ``temperature > 1.0`` flattens it: unlikely characters are picked
          more often, so output is more exploratory / novel;
        * ``temperature == 1.0`` samples from ``preds`` unchanged.

        Returns:
            int: the sampled index into ``preds``.
        """
        preds = np.asarray(preds).astype('float64')
        scaled = np.power(preds, 1. / temperature)
        preds = scaled / np.sum(scaled)
        prob = np.random.choice(range(len(preds)), 1, p=preds)
        return int(prob.squeeze())

    def generate_sample_result(self, epoch, logs):
        """Keras ``on_epoch_end`` callback: every 5th epoch, print and log samples.

        One poem is generated at each temperature 0.7, 1.0 and 1.3; each is
        printed and appended to ``out/out.txt``.
        """
        if epoch % 5 != 0:
            return
        # Opening 'out/out.txt' fails with FileNotFoundError if 'out/' is missing.
        os.makedirs('out', exist_ok=True)
        with open('out/out.txt', 'a', encoding='utf-8') as f:
            f.write('==================第{}轮=====================\n'.format(epoch))
            print("\n==================第{}轮=====================".format(epoch))
            for diversity in [0.7, 1.0, 1.3]:
                print("------------设定诗词创作自由度约束参数为{}--------------".format(diversity))
                generate = self.predict_random(temperature=diversity)
                print(generate)
                # predict_random returns None when no model or seed poem is available.
                if generate is not None:
                    f.write(generate + '\n')

    def _random_poem(self, min_len):
        """Return a random corpus poem with at least ``min_len`` characters, or None."""
        candidates = [poem for poem in self.poems if len(poem) >= min_len]
        if not candidates:
            print('没有长度足够的诗句可用作开头!')
            return None
        return random.choice(candidates)

    def _seed_ending_with(self, char):
        """Build a ``max_len`` context: the tail of a random poem followed by ``char``.

        Returns None if no corpus poem is long enough.
        """
        keep = self.config.max_len - 1
        poem = self._random_poem(keep)
        if poem is None:
            return None
        # poem[len(poem) - keep:] rather than poem[-keep:], which would return
        # the whole poem when keep == 0.
        return poem[len(poem) - keep:] + char

    def predict_random(self, temperature=1):
        """Prediction mode 1: seed with the opening of a random corpus poem and continue it.

        Returns:
            The generated poem string, or None if no model / seed is available.
        """
        if not self.model:
            print('没有预训练模型可用于加载!')
            return
        # Only poems with at least max_len characters can seed predict_sen.
        poem = self._random_poem(self.config.max_len)
        if poem is None:
            return None
        sentence = poem[:self.config.max_len]
        return self.predict_sen(sentence, temperature=temperature)

    def predict_first(self, char, temperature=1):
        """Prediction mode 2: generate a five-character quatrain starting with ``char``.

        Returns:
            ``char`` followed by 23 generated characters, or None if no model /
            seed is available.
        """
        if not self.model:
            print('没有预训练模型可用于加载!')
            return
        sentence = self._seed_ending_with(char)
        if sentence is None:
            return None
        generate = str(char)
        generate += self._preds(sentence, length=23, temperature=temperature)
        return generate

    def predict_sen(self, text, temperature=1):
        """Prediction mode 3: continue a poem from its first ``max_len`` characters.

        Only the last ``max_len`` characters of ``text`` are used as the seed
        (in this setup, the first line of the poem including its comma).

        Returns:
            The seed followed by ``24 - max_len`` generated characters, or None
            if there is no model or ``text`` is shorter than ``max_len``.
        """
        if not self.model:
            return
        max_len = self.config.max_len
        if len(text) < max_len:
            print('给出的初始字数不低于 ', max_len)
            return

        sentence = text[-max_len:]
        print('第一行为:', sentence)
        generate = str(sentence)
        generate += self._preds(sentence, length=24 - max_len, temperature=temperature)
        return generate

    def predict_hide(self, text, temperature=1):
        """Prediction mode 4: acrostic quatrain whose lines start with the 4 chars of ``text``.

        Each line is the given head character followed by 5 generated
        characters, for 24 characters in total.

        Returns:
            The generated poem, or None on invalid input / missing model or seed.
        """
        if not self.model:
            print('没有预训练模型可用于加载!')
            return
        if len(text) != 4:
            print('藏头诗的输入必须是4个字!')
            return
        sentence = self._seed_ending_with(text[0])
        if sentence is None:
            return None
        print('第一行为 ', sentence)

        generate = ''
        for line_no, head in enumerate(text):
            if line_no:
                # The seed already ends with text[0]; later heads are pushed in here.
                sentence = sentence[1:] + head
            generate += head
            for _ in range(5):
                next_char = self._pred(sentence, temperature)
                sentence = sentence[1:] + next_char
                generate += next_char
        return generate

    def _preds(self, sentence, length=23, temperature=1):
        """Autoregressively predict ``length`` characters following ``sentence``.

        Args:
            sentence: seed string; only its first ``max_len`` characters are used.
            length: number of characters to generate.
            temperature: sampling temperature forwarded to ``sample``.

        Returns:
            str: the generated characters (the seed itself is not included).
        """
        sentence = sentence[:self.config.max_len]
        generated = []
        for _ in range(length):
            pred = self._pred(sentence, temperature)
            generated.append(pred)
            sentence = sentence[1:] + pred
        return ''.join(generated)

    def _pred(self, sentence, temperature=1):
        """Predict one character from the last ``max_len`` characters of ``sentence``.

        Returns None (after printing an error) if ``sentence`` is shorter than
        ``max_len``.
        """
        max_len = self.config.max_len
        if len(sentence) < max_len:
            print('in def _pred,length error ')
            return
        sentence = sentence[-max_len:]
        x_pred = np.zeros((1, max_len, len(self.words)))
        for t, char in enumerate(sentence):
            x_pred[0, t, self.word2idx_dic(char)] = 1.
        preds = self.model.predict(x_pred, verbose=0)[0]
        next_index = self.sample(preds, temperature=temperature)
        return self.idx2word[next_index]

    def data_generator(self):
        """Yield one-hot ``(x, y)`` training pairs forever, one sample per step.

        ``x`` has shape ``(1, max_len, vocab)`` and encodes ``max_len``
        consecutive corpus characters. ``y`` has shape ``(1, vocab)`` and
        encodes the character that follows them. Windows touching the ``']'``
        poem separator are skipped. At the end of the corpus, iteration
        restarts from the beginning.

        Raises:
            ValueError: if the corpus contains no usable window at all.
        """
        content = self.files_content
        max_len = self.config.max_len
        vocab_size = len(self.words)
        # A window starting at i is usable only if its target content[i + max_len] exists.
        end = len(content) - max_len
        i = 0
        produced = False
        while True:
            if i >= end:
                if not produced:
                    raise ValueError(
                        'no training sample available: every window touches the "]" '
                        'separator or the corpus is shorter than max_len + 1')
                # Wrap around instead of indexing past the end of the corpus.
                i = 0
                produced = False
            x = content[i:i + max_len]
            y = content[i + max_len]
            i += 1
            if ']' in x or ']' in y:
                continue

            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            y_vec = np.zeros(shape=(1, vocab_size), dtype=bool)
            y_vec[0, self.word2idx_dic(y)] = 1.0

            x_vec = np.zeros(shape=(1, max_len, vocab_size), dtype=bool)
            for t, char in enumerate(x):
                x_vec[0, t, self.word2idx_dic(char)] = 1.0

            produced = True
            yield x_vec, y_vec

    def train(self):
        """Build the model if needed and fit it on ``data_generator``.

        The epoch count is the corpus length minus ``(max_len + 1)`` per poem
        (an estimate of the usable windows), divided by ``config.batch_size``
        and by 1.5. ``config.batch_size`` is used as ``steps_per_epoch``, and
        each step draws a single sample from the generator.
        """
        print('开始训练...')
        number_of_epoch = len(self.files_content) - (self.config.max_len + 1) * self.poems_num
        number_of_epoch /= self.config.batch_size
        # Clamp so a tiny corpus still trains for one epoch instead of epochs <= 0.
        number_of_epoch = max(1, int(number_of_epoch / 1.5))
        print('总迭代轮次为 ', number_of_epoch)
        print('总诗词数量为 ', self.poems_num)
        print('文件内容的长度为 ', len(self.files_content))

        if not self.model:
            self.build_model()

        # NOTE(review): fit_generator is deprecated in TF2 Keras (Model.fit accepts
        # generators); kept because the Keras version in use is not visible here.
        self.model.fit_generator(
            generator=self.data_generator(),
            verbose=True,
            steps_per_epoch=self.config.batch_size,
            epochs=number_of_epoch,
            callbacks=[
                ModelCheckpoint(self.config.weight_file, save_weights_only=False),
                LambdaCallback(on_epoch_end=self.generate_sample_result),
            ],
        )
|