# language_model.py (forked from rafaljozefowicz/lm)
import tensorflow as tf
from model_utils import sharded_variable, LSTMCell
from common import assign_to_gpu, average_grads, find_trainable_variables
from hparams import HParams


class LM(object):
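    """Multi-GPU LSTM language model: one tower per GPU over a slice of the
    batch, with variables shared across towers and gradients averaged."""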
def __init__(self, hps, mode="train", ps_device="/gpu:0"):
self.hps = hps
data_size = hps.batch_size * hps.num_gpus
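        # int32 inputs, targets, and per-token loss weights for the combined
        # batch (batch_size examples per GPU, num_steps tokens each).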
self.x = tf.placeholder(tf.int32, [data_size, hps.num_steps])
self.y = tf.placeholder(tf.int32, [data_size, hps.num_steps])
self.w = tf.placeholder(tf.int32, [data_size, hps.num_steps])
losses = []
tower_grads = []
xs = tf.split(0, hps.num_gpus, self.x)
ys = tf.split(0, hps.num_gpus, self.y)
ws = tf.split(0, hps.num_gpus, self.w)
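        # Build one tower per GPU on its slice of the batch; towers after the
        # first reuse the variables created by the first tower.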
for i in range(hps.num_gpus):
with tf.device(assign_to_gpu(i, ps_device)), tf.variable_scope(tf.get_variable_scope(),
reuse=True if i > 0 else None):
loss = self._forward(i, xs[i], ys[i], ws[i])
losses += [loss]
if mode == "train":
cur_grads = self._backward(loss, summaries=(i == hps.num_gpus - 1))
tower_grads += [cur_grads]
self.loss = tf.add_n(losses) / len(losses)
tf.scalar_summary("model/loss", self.loss)
self.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.zeros_initializer,
trainable=False)
if mode == "train":
grads = average_grads(tower_grads)
optimizer = tf.train.AdagradOptimizer(hps.learning_rate, initial_accumulator_value=1.0)
self.train_op = optimizer.apply_gradients(grads, global_step=self.global_step)
self.summary_op = tf.merge_all_summaries()
else:
self.train_op = tf.no_op()
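        # Track an exponential moving average of the LSTM weights; avg_dict maps
        # variable names to their averaged shadow values for later restoration.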
if mode in ["train", "eval"] and hps.average_params:
with tf.name_scope(None): # This is needed due to EMA implementation silliness.
# Keep track of moving average of LSTM variables.
ema = tf.train.ExponentialMovingAverage(decay=0.999)
variables_to_average = find_trainable_variables("LSTM")
self.train_op = tf.group(*[self.train_op, ema.apply(variables_to_average)])
self.avg_dict = ema.variables_to_restore(variables_to_average)

    def _forward(self, gpu, x, y, w):
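        """Builds the forward pass for one GPU tower and returns its weighted mean token loss."""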
hps = self.hps
w = tf.to_float(w)
self.initial_states = []
for i in range(hps.num_layers):
with tf.device("/gpu:%d" % gpu):
v = tf.Variable(tf.zeros([hps.batch_size, hps.state_size + hps.projected_size]), trainable=False,
collections=[tf.GraphKeys.LOCAL_VARIABLES], name="state_%d_%d" % (gpu, i))
self.initial_states += [v]
emb_vars = sharded_variable("emb", [hps.vocab_size, hps.emb_size], hps.num_shards)
x = tf.nn.embedding_lookup(emb_vars, x) # [bs, steps, emb_size]
if hps.keep_prob < 1.0:
x = tf.nn.dropout(x, hps.keep_prob)
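        # Unroll the time dimension into a list of num_steps tensors of shape
        # [batch_size, emb_size].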
inputs = [tf.squeeze(v, [1]) for v in tf.split(1, hps.num_steps, x)]
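        # Run each LSTM layer, statically unrolled over time. The layer's state
        # is assigned back into initial_states so it carries over to the next
        # minibatch.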
for i in range(hps.num_layers):
with tf.variable_scope("lstm_%d" % i):
cell = LSTMCell(hps.state_size, hps.emb_size, num_proj=hps.projected_size)
state = self.initial_states[i]
for t in range(hps.num_steps):
inputs[t], state = cell(inputs[t], state)
if hps.keep_prob < 1.0:
inputs[t] = tf.nn.dropout(inputs[t], hps.keep_prob)
with tf.control_dependencies([self.initial_states[i].assign(state)]):
inputs[t] = tf.identity(inputs[t])
inputs = tf.reshape(tf.concat(1, inputs), [-1, hps.projected_size])
# Initialization ignores the fact that softmax_w is transposed. That worked slightly better.
softmax_w = sharded_variable("softmax_w", [hps.vocab_size, hps.projected_size], hps.num_shards)
softmax_b = tf.get_variable("softmax_b", [hps.vocab_size])
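        # Full softmax when num_sampled == 0; otherwise a sampled softmax over
        # num_sampled candidate classes.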
if hps.num_sampled == 0:
full_softmax_w = tf.reshape(tf.concat(1, softmax_w), [-1, hps.projected_size])
full_softmax_w = full_softmax_w[:hps.vocab_size, :]
logits = tf.matmul(inputs, full_softmax_w, transpose_b=True) + softmax_b
# targets = tf.reshape(tf.transpose(self.y), [-1])
targets = tf.reshape(y, [-1])
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
else:
targets = tf.reshape(y, [-1, 1])
loss = tf.nn.sampled_softmax_loss(softmax_w, softmax_b, tf.to_float(inputs),
targets, hps.num_sampled, hps.vocab_size)
loss = tf.reduce_mean(loss * tf.reshape(w, [-1]))
return loss

    def _backward(self, loss, summaries=False):
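        """Computes, rescales, and clips gradients for one tower; returns (grad, var) pairs."""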
hps = self.hps
loss = loss * hps.num_steps
emb_vars = find_trainable_variables("emb")
lstm_vars = find_trainable_variables("LSTM")
softmax_vars = find_trainable_variables("softmax")
all_vars = emb_vars + lstm_vars + softmax_vars
grads = tf.gradients(loss, all_vars)
orig_grads = grads[:]
emb_grads = grads[:len(emb_vars)]
grads = grads[len(emb_vars):]
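        # Embedding gradients come back as IndexedSlices; scale their values by
        # the batch size.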
for i in range(len(emb_grads)):
assert isinstance(emb_grads[i], tf.IndexedSlices)
emb_grads[i] = tf.IndexedSlices(emb_grads[i].values * hps.batch_size, emb_grads[i].indices,
emb_grads[i].dense_shape)
lstm_grads = grads[:len(lstm_vars)]
softmax_grads = grads[len(lstm_vars):]
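        # Clip only the LSTM gradients by global norm; embedding and softmax
        # gradients are left unclipped.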
lstm_grads, lstm_norm = tf.clip_by_global_norm(lstm_grads, hps.max_grad_norm)
clipped_grads = emb_grads + lstm_grads + softmax_grads
assert len(clipped_grads) == len(orig_grads)
if summaries:
tf.scalar_summary("model/lstm_grad_norm", lstm_norm)
tf.scalar_summary("model/lstm_grad_scale", tf.minimum(hps.max_grad_norm / lstm_norm, 1.0))
tf.scalar_summary("model/lstm_weight_norm", tf.global_norm(lstm_vars))
# for v, g, cg in zip(all_vars, orig_grads, clipped_grads):
# name = v.name.lstrip("model/")
# tf.histogram_summary(name + "/var", v)
# tf.histogram_summary(name + "/grad", g)
# tf.histogram_summary(name + "/clipped_grad", cg)
return list(zip(clipped_grads, all_vars))

    @staticmethod
def get_default_hparams():
return HParams(
batch_size=128,
num_steps=20,
num_shards=8,
num_layers=1,
learning_rate=0.2,
max_grad_norm=10.0,
num_delayed_steps=150,
keep_prob=0.9,
vocab_size=793470,
emb_size=512,
state_size=2048,
projected_size=512,
num_sampled=8192,
num_gpus=1,
average_params=True,
run_profiler=False,
)
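

# A minimal usage sketch (a hedged illustration, not part of the original file):
# graph construction only; creating a Session, initializing variables, and
# feeding model.x / model.y / model.w with real data are omitted.
#
#   hps = LM.get_default_hparams()
#   model = LM(hps, mode="train", ps_device="/gpu:0")
#   # sess.run(model.train_op, feed_dict={model.x: x, model.y: y, model.w: w})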