import numpy as np
import torch
from sklearn.feature_extraction.text import CountVectorizer

from octis.models.ETM_model import data, etm
from octis.models.base_etm import BaseETM
from octis.models.early_stopping.pytorchtools import EarlyStopping
class ETM(BaseETM):
def __init__(
self, num_topics=10, num_epochs=100, t_hidden_size=800, rho_size=300,
embedding_size=300, activation='relu', dropout=0.5, lr=0.005,
optimizer='adam', batch_size=128, clip=0.0, wdecay=1.2e-6, bow_norm=1,
device='cpu', train_embeddings=True, embeddings_path=None,
embeddings_type='pickle', binary_embeddings=True,
headerless_embeddings=False, use_partitions=True):
"""
initialization of ETM
:param embeddings_path: string, path to embeddings file.
Can be a binary file for the 'pickle', 'keyedvectors' and
'word2vec' types or a text file for 'word2vec'.
This parameter is only used if 'train_embeddings' is set to False
:param embeddings_type: string, defines the format of the embeddings
file. Possible values are 'pickle', 'keyedvectors' or 'word2vec'.
If set to 'pickle', you must provide a file created with 'pickle'
containing an array of word embeddings, composed by words and
their respective vectors. If set to 'keyedvectors', you must
provide a file containing a saved gensim.models.KeyedVectors
instance. If set to 'word2vec', you must provide a file with the
original word2vec format. This parameter is only used if
'train_embeddings' is set to False (default 'pickle')
:param binary_embeddings: bool, indicates if the original word2vec
embeddings file is binary or textual. This parameter is only used
if both 'embeddings_type' is set to 'word2vec' and
'train_embeddings' is set to False. Otherwise, it will be ignored
(default True)
:param headerless_embeddings: bool, indicates if the original word2vec
embeddings textual file has a header line in the format
"<no_of_vectors> <vector_length>". This parameter is only used if
'embeddings_type' is set to 'word2vec', 'train_embeddings' is set
to False and 'binary_embeddings' is set to False. Otherwise, it
will be ignored (default False)
"""
super(ETM, self).__init__()
self.hyperparameters = dict()
self.hyperparameters['num_topics'] = int(num_topics)
self.hyperparameters['num_epochs'] = int(num_epochs)
self.hyperparameters['t_hidden_size'] = int(t_hidden_size)
self.hyperparameters['rho_size'] = int(rho_size)
self.hyperparameters['embedding_size'] = int(embedding_size)
self.hyperparameters['activation'] = activation
self.hyperparameters['dropout'] = float(dropout)
self.hyperparameters['lr'] = float(lr)
self.hyperparameters['optimizer'] = optimizer
self.hyperparameters['batch_size'] = int(batch_size)
self.hyperparameters['clip'] = float(clip)
self.hyperparameters['wdecay'] = float(wdecay)
self.hyperparameters['bow_norm'] = int(bow_norm)
self.hyperparameters['train_embeddings'] = bool(train_embeddings)
self.hyperparameters['embeddings_path'] = embeddings_path
assert embeddings_type in ['pickle', 'word2vec', 'keyedvectors'], \
"embeddings_type must be 'pickle', 'word2vec' or 'keyedvectors'."
self.hyperparameters['embeddings_type'] = embeddings_type
self.hyperparameters['binary_embeddings'] = binary_embeddings
self.hyperparameters['headerless_embeddings'] = headerless_embeddings
self.early_stopping = None
self.device = device
self.test_tokens, self.test_counts = None, None
self.valid_tokens, self.valid_counts = None, None
self.train_tokens, self.train_counts, self.vocab = None, None, None
self.use_partitions = use_partitions
self.model = None
self.optimizer = None
self.embeddings = None
    def train_model(
self, dataset, hyperparameters=None, top_words=10,
op_path='checkpoint.pt'):
if hyperparameters is None:
hyperparameters = {}
self.set_model(dataset, hyperparameters)
self.top_words = top_words
self.early_stopping = EarlyStopping(
patience=5, verbose=True, path=op_path)
for epoch in range(0, self.hyperparameters['num_epochs']):
continue_training = self._train_epoch(epoch)
if not continue_training:
break
        # optionally reload the best checkpoint saved by early stopping:
        # self.model.load_state_dict(torch.load(op_path))
if self.use_partitions:
result = self.inference()
else:
result = self.get_info()
return result
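        # The returned dictionary follows the OCTIS output convention:
        # 'topics' holds the `top_words` most probable words per topic, and
        # 'topic-document-matrix' is theta transposed, with shape
        # (num_topics, num_documents).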
def set_model(self, dataset, hyperparameters):
if self.use_partitions:
train_data, validation_data, testing_data = (
dataset.get_partitioned_corpus(use_validation=True))
data_corpus_train = [' '.join(i) for i in train_data]
data_corpus_test = [' '.join(i) for i in testing_data]
data_corpus_val = [' '.join(i) for i in validation_data]
vocab = dataset.get_vocabulary()
self.vocab = {i: w for i, w in enumerate(vocab)}
vocab2id = {w: i for i, w in enumerate(vocab)}
(self.train_tokens, self.train_counts, self.test_tokens,
self.test_counts, self.valid_tokens, self.valid_counts
) = self.preprocess(
vocab2id, data_corpus_train, data_corpus_test, data_corpus_val)
else:
data_corpus = [' '.join(i) for i in dataset.get_corpus()]
vocab = dataset.get_vocabulary()
self.vocab = {i: w for i, w in enumerate(vocab)}
vocab2id = {w: i for i, w in enumerate(vocab)}
self.train_tokens, self.train_counts = self.preprocess(
vocab2id, data_corpus, None)
        # note: this overrides whatever 'device' was passed to the constructor
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
self.set_default_hyperparameters(hyperparameters)
self.load_embeddings()
# define model and optimizer
self.model = etm.ETM(
num_topics=self.hyperparameters['num_topics'],
vocab_size=len(self.vocab.keys()),
t_hidden_size=int(self.hyperparameters['t_hidden_size']),
rho_size=int(self.hyperparameters['rho_size']),
emb_size=int(self.hyperparameters['embedding_size']),
theta_act=self.hyperparameters['activation'],
embeddings=self.embeddings,
train_embeddings=self.hyperparameters['train_embeddings'],
enc_drop=self.hyperparameters['dropout']).to(self.device)
print('model: {}'.format(self.model))
self.optimizer = self.set_optimizer()
def _train_epoch(self, epoch):
self.data_list = []
self.model.train()
acc_loss = 0
acc_kl_theta_loss = 0
cnt = 0
indices = torch.arange(0, len(self.train_tokens))
indices = torch.split(indices, self.hyperparameters['batch_size'])
for idx, ind in enumerate(indices):
self.optimizer.zero_grad()
self.model.zero_grad()
data_batch = data.get_batch(
self.train_tokens, self.train_counts, ind,
len(self.vocab.keys()), self.device)
sums = data_batch.sum(1).unsqueeze(1)
if self.hyperparameters['bow_norm']:
normalized_data_batch = data_batch / sums
else:
normalized_data_batch = data_batch
recon_loss, kld_theta = self.model(
data_batch, normalized_data_batch)
total_loss = recon_loss + kld_theta
total_loss.backward()
if self.hyperparameters["clip"] > 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(),
self.hyperparameters["clip"])
self.optimizer.step()
acc_loss += torch.sum(recon_loss).item()
acc_kl_theta_loss += torch.sum(kld_theta).item()
cnt += 1
log_interval = 20
if idx % log_interval == 0 and idx > 0:
cur_loss = round(acc_loss / cnt, 2)
cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
cur_real_loss = round(cur_loss + cur_kl_theta, 2)
print(
'Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} ..'
' Rec_loss: {} .. NELBO: {}'.format(
epoch + 1, idx, len(indices),
self.optimizer.param_groups[0]['lr'],
cur_kl_theta, cur_loss, cur_real_loss))
self.data_list.append(normalized_data_batch)
cur_loss = round(acc_loss / cnt, 2)
cur_kl_theta = round(acc_kl_theta_loss / cnt, 2)
cur_real_loss = round(cur_loss + cur_kl_theta, 2)
print('*' * 100)
print(
'Epoch----->{} .. LR: {} .. KL_theta: {} .. '
'Rec_loss: {} .. NELBO: {}'.format(
epoch + 1, self.optimizer.param_groups[0]['lr'], cur_kl_theta,
cur_loss, cur_real_loss))
print('*' * 100)
        # validation on the held-out partition
if self.valid_tokens is None:
return True
else:
model = self.model.to(self.device)
model.eval()
with torch.no_grad():
val_acc_loss = 0
val_acc_kl_theta_loss = 0
val_cnt = 0
indices = torch.arange(0, len(self.valid_tokens))
indices = torch.split(
indices, self.hyperparameters['batch_size'])
for idx, ind in enumerate(indices):
val_data_batch = data.get_batch(
self.valid_tokens, self.valid_counts,
ind, len(self.vocab.keys()), self.device)
sums = val_data_batch.sum(1).unsqueeze(1)
if self.hyperparameters['bow_norm']:
val_normalized_data_batch = val_data_batch / sums
else:
val_normalized_data_batch = val_data_batch
val_recon_loss, val_kld_theta = self.model(
val_data_batch, val_normalized_data_batch)
val_acc_loss += torch.sum(val_recon_loss).item()
val_acc_kl_theta_loss += torch.sum(val_kld_theta).item()
val_cnt += 1
                # average the accumulated losses over validation batches
                val_cur_loss = round(val_acc_loss / val_cnt, 2)
                val_cur_kl_theta = round(val_acc_kl_theta_loss / val_cnt, 2)
                val_cur_real_loss = round(val_cur_loss + val_cur_kl_theta, 2)
print('*' * 100)
print(
'VALIDATION .. LR: {} .. KL_theta: {} .. Rec_loss: {}'
' .. NELBO: {}'.format(
self.optimizer.param_groups[0]['lr'], val_cur_kl_theta,
val_cur_loss, val_cur_real_loss))
print('*' * 100)
if np.isnan(val_cur_real_loss):
return False
else:
                # track the epoch-averaged validation NELBO, not the loss of
                # the last mini-batch only
                self.early_stopping(val_cur_real_loss, model)
if self.early_stopping.early_stop:
print("Early stopping")
return False
else:
return True
def get_info(self):
topic_w = []
self.model.eval()
info = {}
with torch.no_grad():
theta, _ = self.model.get_theta(torch.cat(self.data_list))
gammas = self.model.get_beta().cpu().numpy()
for k in range(self.hyperparameters['num_topics']):
if np.isnan(gammas[k]).any():
# to deal with nan matrices
topic_w = None
break
else:
top_words = list(
gammas[k].argsort()[-self.top_words:][::-1])
topic_words = [self.vocab[a] for a in top_words]
topic_w.append(topic_words)
info['topic-word-matrix'] = gammas
info['topic-document-matrix'] = theta.cpu().detach().numpy().T
info['topics'] = topic_w
return info
def inference(self):
assert isinstance(self.use_partitions, bool) and self.use_partitions
topic_d = []
self.model.eval()
indices = torch.arange(0, len(self.test_tokens))
indices = torch.split(indices, self.hyperparameters['batch_size'])
for idx, ind in enumerate(indices):
data_batch = data.get_batch(self.test_tokens, self.test_counts,
ind, len(self.vocab.keys()),
self.device)
sums = data_batch.sum(1).unsqueeze(1)
if self.hyperparameters['bow_norm']:
normalized_data_batch = data_batch / sums
else:
normalized_data_batch = data_batch
theta, _ = self.model.get_theta(normalized_data_batch)
topic_d.append(theta.cpu().detach().numpy())
info = self.get_info()
        # concatenate the per-batch document-topic distributions
        if topic_d:
            emp_array = np.concatenate(topic_d, axis=0)
        else:
            emp_array = np.empty((0, self.hyperparameters['num_topics']))
info['test-topic-document-matrix'] = emp_array.T
return info
    def set_default_hyperparameters(self, hyperparameters):
        # override the defaults only for recognized hyperparameter names
        for k, v in hyperparameters.items():
            if k in self.hyperparameters:
                self.hyperparameters[k] = v
def partitioning(self, use_partitions=False):
self.use_partitions = use_partitions
@staticmethod
def preprocess(
vocab2id, train_corpus, test_corpus=None, validation_corpus=None):
        def split_bow(bow_in, n_docs):
            # split a sparse BOW matrix into per-document lists of token
            # indices and the corresponding counts
            indices = [list(bow_in[doc, :].indices) for doc in range(n_docs)]
            counts = [list(bow_in[doc, :].data) for doc in range(n_docs)]
            return indices, counts
vec = CountVectorizer(
vocabulary=vocab2id, token_pattern=r'(?u)\b\w+\b')
dataset = train_corpus.copy()
if test_corpus is not None:
dataset.extend(test_corpus)
if validation_corpus is not None:
dataset.extend(validation_corpus)
vec.fit(dataset)
x_train = vec.transform(train_corpus)
x_train_tokens, x_train_count = split_bow(x_train, x_train.shape[0])
if test_corpus is not None:
x_test = vec.transform(test_corpus)
x_test_tokens, x_test_count = split_bow(x_test, x_test.shape[0])
if validation_corpus is not None:
x_validation = vec.transform(validation_corpus)
x_val_tokens, x_val_count = split_bow(
x_validation, x_validation.shape[0])
return (
x_train_tokens, x_train_count, x_test_tokens,
x_test_count, x_val_tokens, x_val_count)
else:
return (
x_train_tokens, x_train_count, x_test_tokens, x_test_count)
else:
if validation_corpus is not None:
x_validation = vec.transform(validation_corpus)
x_val_tokens, x_val_count = split_bow(
x_validation, x_validation.shape[0])
return x_train_tokens, x_train_count, x_val_tokens, x_val_count
else:
return x_train_tokens, x_train_count
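

if __name__ == '__main__':
    # Minimal end-to-end sketch, assuming the standard OCTIS Dataset API;
    # '20NewsGroup' is one of the preprocessed datasets bundled with OCTIS.
    from octis.dataset.dataset import Dataset

    dataset = Dataset()
    dataset.fetch_dataset('20NewsGroup')
    model = ETM(num_topics=20, num_epochs=5)
    output = model.train_model(dataset)
    print(output['topics'])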