-
Notifications
You must be signed in to change notification settings - Fork 4
/
state.h
300 lines (240 loc) · 10.6 KB
/
state.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
#ifndef STATE_H
#define STATE_H
#include "corpus.h"
#include <map>
#include <vector>
using namespace std;
/// Hyperparameters handed to the HDP Gibbs sampler.
/// The (_a, _b) pairs belong to the concentration parameters gamma and alpha
/// (presumably the parameters of their priors -- confirm in the sampler code);
/// m_max_iter is presumably the maximum number of sampler iterations.
class hdp_hyperparameter
{
public:
    /// hyperparameter pair associated with gamma
    double m_gamma_a;
    double m_gamma_b;
    /// hyperparameter pair associated with alpha
    double m_alpha_a;
    double m_alpha_b;
    /// iteration budget
    int m_max_iter;

    /// Set every hyperparameter in one call. Each argument is stored in the
    /// member with the same name minus the leading underscore.
    void setup_parameters(double _gamma_a, double _gamma_b,
                          double _alpha_a, double _alpha_b,
                          int _max_iter)
    {
        this->m_max_iter = _max_iter;
        this->m_alpha_b  = _alpha_b;
        this->m_alpha_a  = _alpha_a;
        this->m_gamma_b  = _gamma_b;
        this->m_gamma_a  = _gamma_a;
    }

    /// Presumably copies every hyperparameter from _input_hyperparam
    /// (defined out of line -- signature must stay as declared here).
    void copy_parameters(const hdp_hyperparameter _input_hyperparam);
};
typedef std::vector<int>    int_vec;    // growable array of int
typedef std::vector<double> double_vec; // growable array of double
/// Per-token record; doc_state keeps an array of these in m_words.
struct word_info
{
    int m_word_index;       // index of the word (presumably into the vocabulary)
    int m_table_assignment; // table this token sits at, within its document
    // int m_topic_assignment; // disabled: redundant, the topic can be read
    //                         // through doc_state::m_table_to_topic
};
/// Sampler state for one document: how its tokens are seated at tables and,
/// for each table, which topic it serves and how many tokens it holds.
/// NOTE(review): m_words is a raw word_info* and the class declares a
/// destructor but no copy constructor / copy assignment. A copy would share
/// the same buffer; if the destructor frees it, that is a double free.
/// Instances are referenced through pointers (see doc_state** in the state
/// classes) -- avoid copying by value.
class doc_state
{
public:
    int m_doc_id;      // document id
    int m_doc_length;  // document length
    int m_num_tables;  // number of tables opened in this document
    word_info* m_words; // per-token records (presumably m_doc_length entries -- confirm in setup_state_from_doc)
    int_vec m_table_to_topic;   // this document's table index -> topic index
    int_vec m_word_counts_by_t; // token count at each table
    // vector< vector<int> > m_words_by_zi; // (disabled) word indices grouped by topic z, then i

    doc_state();
    ~doc_state();

    /// Initialise this state from a corpus document (defined out of line).
    void setup_state_from_doc(const document* doc);
    /// Presumably releases m_words and resets the counts (defined out of line).
    void free_doc_state();
};
/// Bundle of aggregate topic counts; the member names mirror the
/// corresponding fields of the HDP state classes.
/// NOTE(review): m_word_counts_by_zw holds raw int* arrays and the class has a
/// user-declared destructor but no copy control. The implicit copy copies the
/// pointers only, so two instances may end up sharing (and, if the destructor
/// frees them, double-freeing) the same arrays.
class counts
{
public:
    int          m_num_words;          // total number of words
    int          m_total_num_tables;   // total number of tables
    int_vec      m_num_tables_by_z;    // tables per topic
    int_vec      m_word_counts_by_z;   // words per topic
    vector<int*> m_word_counts_by_zw;  // [topic] -> per-word count array

    counts();
    counts(int _m_num_words,
           int _m_total_num_tables,
           const int_vec& _m_num_tables_by_z,
           const int_vec& _m_word_counts_by_z,
           const vector<int*>& _m_word_counts_by_zw);
    ~counts();

    /// Assign all counts from the given values (defined out of line; whether
    /// the per-word arrays are deep-copied is decided there).
    void set_counts(int _m_num_words,
                    int _m_total_num_tables,
                    const int_vec& _m_num_tables_by_z,
                    const int_vec& _m_word_counts_by_z,
                    const vector<int*>& _m_word_counts_by_zw);
};
/// Per-topic word counts: a total for each topic plus, for each topic, an
/// array of per-word counts.
/// NOTE(review): the int* arrays are raw pointers and there is no copy
/// control; an implicit copy shares them between instances.
class word_counts
{
public:
    int_vec      m_word_counts_by_z;  // word count of each topic
    vector<int*> m_word_counts_by_zw; // [topic] -> per-word count array

    word_counts();
    ~word_counts();
};
/// Full Gibbs-sampler state of the (non-dynamic) HDP model: corpus-level
/// sizes, the per-document states, and the topic-level count tables.
/// Methods come in batch / "_online" pairs; the online variants work on the
/// current document using the *_online fields below.
/// NOTE(review): data-member order is significant (it fixes layout and the
/// initialization order of the out-of-line constructor) -- do not reorder.
/// NOTE(review): m_doc_states and the int* entries in the vector<int*> members
/// are raw pointers; there is no copy constructor/assignment, so copying an
/// hdp_state would share them. Treat instances as non-copyable.
class hdp_state
{
public:
/// corpus-level information; fixed once the corpus is set up
int m_size_vocab;
int m_total_words;
int m_num_docs;
/// document states (array of pointers, presumably m_num_docs entries)
doc_state** m_doc_states;
/// number of topics
int m_num_topics;
/// total number of tables across all topics
int m_total_num_tables;
/// Naming convention for the count members:
///   by_z -- by topic
///   by_d -- by document, for each topic
///   by_w -- by word, for each topic
///   by_t -- by table, for each document
int_vec m_num_tables_by_z; // number of tables serving each topic
int_vec m_word_counts_by_z; // word count of each topic
vector <int*> m_word_counts_by_zd; // word counts for [each topic, each doc]
vector <int*> m_word_counts_by_zw; // word counts for [each topic, each word]
/// online updates: local changes for the current document
/// (the commented-out *_update members below are disabled)
// int_vec m_num_tables_by_z_update;
// int_vec m_word_counts_by_z_update;
int_vec m_word_counts_by_zd_online;
// vector <int*> m_word_counts_by_zw_update;
int m_num_topics_before_update; // number of topics from the batch plus possible new topics
///
/// Dirichlet parameter of the topics
double m_eta;
/// concentration parameters (gamma presumably top level, alpha document
/// level -- confirm against the sampler)
double m_gamma;
double m_alpha;
public:
hdp_state();
virtual ~hdp_state();
public:
/// --- setup and teardown (all defined out of line) ---
void set_vocab_size(int _vocab_size);
void setup_state_from_corpus(const corpus* c);
void setup_doc_info_from_document(const document * doc);
void allocate_initial_space();
void allocate_space_for_online_updates();
void free_state();
void free_doc_info();
void free_counts_update();
void free_online_info();
/// --- Gibbs sweeps: batch over the corpus, online over the current document ---
void iterate_gibbs_state(bool remove, bool permute,
hdp_hyperparameter* hdp_hyperparam,
bool table_sampling=false);
void iterate_gibbs_state_online(bool remove, hdp_hyperparameter* hdp_hyperparam,
bool table_sampling=false);
/// --- sampling steps; q and f are caller-supplied work buffers ---
/// (presumably unnormalised probabilities and likelihoods -- confirm in the .cpp)
void sample_tables(doc_state* d_state, double_vec & q, double_vec & f);
void sample_tables_online(doc_state* d_state, double_vec & q, double_vec & f);
void sample_table_assignment(doc_state* d_state, int t, int* words, double_vec & q, double_vec & f);
void sample_table_assignment_online(doc_state* d_state, int t, int* words, double_vec & q, double_vec & f);
void sample_word_assignment(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);
void sample_word_assignment_online(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);
/// --- count bookkeeping for token i (update is presumably +1/-1; k = topic, -1 = default) ---
void doc_state_update(doc_state* d_state, int i, int update, int k=-1);
void doc_state_update_online(doc_state* d_state, int i, int update, int k=-1);
/// --- compaction; k_to_new_k presumably maps old topic ids to renumbered ones ---
void compact_doc_state(doc_state* d_state, int* k_to_new_k);
void compact_doc_state_online(doc_state* d_state, int* k_to_new_k);
void compact_hdp_state();
void compact_hdp_state_online();
/// --- persistence ---
/// NOTE(review): the file-name parameters are not modified by the signature's
/// contract and should be const char*; change together with the .cpp definitions.
void save_state_to_file(char * name);
void save_state(char * name);
void save_state_ex(char * name, int sampler_num = -1);
void load_state_from_file(char * name);
void load_state_ex(char * name, int sampler_num = -1);
};
/// Gibbs-sampler state of the dynamic HDP variant. It mirrors hdp_state and
/// adds per-document table counts (m_num_tables_by_zd,
/// m_total_num_tables_by_d), a fourth parameter m_delta, and extra bookkeeping
/// for online updates that carries information from the previously processed
/// document.
/// NOTE(review): data-member order is significant (it fixes layout and the
/// initialization order of the out-of-line constructor) -- do not reorder.
/// NOTE(review): raw pointers (doc_state**, int* rows) with no copy control --
/// treat instances as non-copyable.
class hdp_state_dynamic
{
public:
/// corpus-level information; fixed once the corpus is set up
int m_size_vocab;
int m_total_words;
int m_num_docs;
/// document states (array of pointers, presumably m_num_docs entries)
doc_state** m_doc_states;
/// number of topics
int m_num_topics;
/// total number of tables across all topics (during a batch iteration this
/// covers only the documents visited so far)
int m_total_num_tables;
/// Naming convention for the count members:
///   by_z -- by topic
///   by_d -- by document, for each topic
///   by_w -- by word, for each topic
///   by_t -- by table, for each document
int_vec m_word_counts_by_z; // word count of each topic
int_vec m_num_tables_by_z; // number of tables serving each topic (during batch iterations: only the documents visited so far)
vector<int*> m_num_tables_by_zd; // number of tables for [each topic, each doc]
int_vec m_total_num_tables_by_d; // total number of tables in each doc
vector <int*> m_word_counts_by_zd; // word counts for [each topic, each doc]
vector <int*> m_word_counts_by_zw; // word counts for [each topic, each word]
/// online updates: local changes for the current document, plus the table
/// counts saved from the previous document
int_vec m_word_counts_by_zd_online;
int_vec m_num_tables_by_zd_online;
int_vec m_num_tables_by_zd_for_prev_doc_online;
int m_total_num_tables_by_d_online;
int m_total_num_tables_by_d_for_prev_doc_online;
int m_num_topics_before_update; // number of topics coming from the batch
///
/// Dirichlet parameter of the topics
double m_eta;
/// concentration parameters (gamma, alpha) and the dynamic-model parameter
/// delta (its exact role is defined in the sampler .cpp -- confirm there)
double m_gamma;
double m_alpha;
double m_delta;
public:
hdp_state_dynamic();
virtual ~hdp_state_dynamic();
public:
/// --- setup and teardown (all defined out of line) ---
void set_vocab_size(int _vocab_size);
void setup_state_from_corpus(const corpus* c);
void setup_doc_info_from_document(const document * doc);
void allocate_initial_space();
void allocate_space_for_online_updates();
void free_state();
void free_doc_info();
void free_counts_update();
void free_online_info();
void free_current_doc_counts_update();
/// --- Gibbs sweeps: batch over the corpus, online over the current document ---
/// (unlike hdp_state, the batch sweep has no `permute` flag)
void iterate_gibbs_state(bool remove,
hdp_hyperparameter* hdp_hyperparam,
bool table_sampling=false);
void iterate_gibbs_state_online(bool remove, hdp_hyperparameter* hdp_hyperparam,
bool table_sampling=false);
/// presumably rebuilds the running counts for documents before doc_id -- confirm
void reset_counts_up_to_date(int doc_id = 0);
/// --- sampling steps; q and f are caller-supplied work buffers ---
void sample_tables(doc_state* d_state, bool not_init, double_vec & q, double_vec & f);
void sample_tables_online(doc_state* d_state, double_vec & q, double_vec & f);
void sample_table_assignment(doc_state* d_state, int t, bool not_init,
int* words, double_vec & q, double_vec & f);
void sample_table_assignment_online(doc_state* d_state, int t, int* words, double_vec & q, double_vec & f);
void sample_word_assignment(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);
void sample_word_assignment_online(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);
/// --- topic prior helpers (read-only: declared const) ---
int sample_new_topic_for_new_table_online(const vector<double>& likelihood_per_topic_assignment) const;
double compute_prior_probability_for_topic_sampling(int doc_id, int topic_id,
bool decrease_table_counts = false) const;
double compute_prior_probability_for_topic_sampling_online(int topic_id,
bool decrease_table_counts = false) const;
double compute_normalisation_constant_for_topic_prior_online() const;
/// --- count bookkeeping for token i (update is presumably +1/-1; k = topic, -1 = default) ---
void doc_state_update(doc_state* d_state, int i, int update, int k=-1);
void doc_state_update_online(doc_state* d_state, int i, int update, int k=-1);
/// --- compaction; k_to_new_k presumably maps old topic ids to renumbered ones ---
void compact_doc_state(doc_state* d_state, int* k_to_new_k);
void compact_doc_state_online(doc_state* d_state, int* k_to_new_k);
void compact_hdp_state();
void compact_hdp_state_online();
/// --- likelihood of the next document's table assignments (read-only) ---
/// old_topic_id / new_topic_id == -1 presumably means "no topic change" -- confirm
double compute_next_doc_table_assignment_likelihood(int doc_id, int old_topic_id = -1, int new_topic_id = -1, bool new_table = false) const;
double compute_next_table_assignment_likelihood(int doc_id, int old_topic_id = -1,
int new_topic_id = -1, bool new_table = false) const;
double compute_next_table_assignment_likelihood_for_non_born_topics(int doc_id,
const vector<int>& updating_num_tables_by_z,
int old_topic_id = -1,
int new_topic_id = -1,
bool update_counts_for_current_doc = false) const;
/// --- persistence ---
/// NOTE(review): the file-name parameters should be const char*; change
/// together with the .cpp definitions.
void save_state_to_file(char * name);
void save_state_ex(char * name, int sampler_num = -1);
void load_state_from_file(char * name);
void load_state_ex(char * name, int sampler_num = -1);
/// --- online-update hand-off between consecutive documents / iterations ---
void save_online_update_for_next_iteration();
void update_state_after_online_update();
void save_last_document_info_for_online_update();
};
#endif // STATE_H