diff --git a/CNN-task-3.ipynb b/CNN-task-3.ipynb new file mode 100644 index 0000000..cebf8dd --- /dev/null +++ b/CNN-task-3.ipynb @@ -0,0 +1,1254 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5" + }, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "import torch\n", + "import pickle\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "\n", + "import warnings\n", + "from matplotlib import pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from pandas.core.common import SettingWithCopyWarning\n", + "warnings.simplefilter(action=\"ignore\", category=SettingWithCopyWarning)\n", + "\n", + "from torch.utils.data import (TensorDataset, DataLoader, RandomSampler,SequentialSampler)\n", + "\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'\n", + "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "#Used to return first n items of the iterable as a list\n", + "from itertools import islice\n", + "\n", + "def take(n, iterable):\n", + " return list(islice(iterable, n))" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=================================\n", + "GPU found\n", + "Using GPU at cuda: 0\n", + "=================================\n", + " \n" + ] + } + ], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "if device == 'cuda':\n", + " print(\"=================================\")\n", + " print(\"GPU found\")\n", + " print(\"Using GPU at cuda:\",torch.cuda.current_device())\n", + " print(\"=================================\")\n", + " print(\" \")" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ + "model1 = torch.load(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/epoch_300.pt\")\n", + "\n", + "w1 = model1[\"w1.weight\"].T\n", + "w2 = model1[\"w2.weight\"]\n", + "\n", + "\n", + "cleandata = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/hindi_corpus_cleaned.pkl\")\n", + "word_index = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/word_index.pkl\")\n", + "index_word = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/index_word.pkl\")\n", + "V = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/vocab.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\\t')\n", + "text = data[['text','task_1']]\n", + "text['text'] = cleandata['text'].apply(lambda x: x.split())\n", + "text['label'] = text['task_1'].apply(lambda x: 1 if x == 'HOF' else 0)\n", + "max_len = text.text.str.len().max()" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "#Function to count the labels in a dataset\n", + "def tag_count(input):\n", + " hcount,ncount = 0,0\n", + " for tag in input:\n", + " if tag == 1:\n", + " hcount+=1\n", + " else:\n", + " ncount+=1\n", + " return hcount,ncount" + ] + 
},
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "'''The word-index dictionaries are updated with the '' token, which is used for padding,\n",
+    "i.e. to make all sentences uniform in length.'''\n",
+    "\n",
+    "word_index[''] = len(V)\n",
+    "index_word[len(V)] = ''\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Looks up the embedding vector of a word in the trained weight matrix w1\n",
+    "def get_word_embedding(input):\n",
+    "    index = word_index[input]\n",
+    "    return w1[index]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Builds the embedding matrix from the trained word embeddings; the padding token '' gets an all-zero row\n",
+    "def matrix_embeddings():\n",
+    "    _ , emb_size = w1.shape\n",
+    "    embedding_matrix = np.random.uniform(-1, 1, (len(word_index), emb_size))\n",
+    "    embedding_matrix[word_index['']] = np.zeros((emb_size,))\n",
+    "\n",
+    "    for word,i in take(len(V),word_index.items()):\n",
+    "        temp = get_word_embedding(word)\n",
+    "        if temp is not None:\n",
+    "            embedding_matrix[i] = temp.cpu()\n",
+    "    return embedding_matrix"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 71,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Encodes each sentence as a list of vocabulary indices, right-padded to max_len\n",
+    "def encode(corpus):\n",
+    "    sent_idx = []\n",
+    "    for sentence in corpus:\n",
+    "        sentence += [''] * (max_len - len(sentence))\n",
+    "        idx = [word_index[word] for word in sentence]\n",
+    "        sent_idx.append(idx)\n",
+    "    return np.array(sent_idx)"
+   ]
+  },
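+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check of the padding/encoding step on a made-up two-sentence corpus (the toy vocabulary below is illustrative only, not taken from the dataset): every sentence is right-padded with the '' token to `max_len` and mapped to indices, so the result is a rectangular integer array."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Illustrative sketch only: toy_word_index and toy_max_len stand in for word_index and max_len\n",
+    "toy_word_index = {'a': 0, 'b': 1, 'c': 2, '': 3}\n",
+    "toy_max_len = 4\n",
+    "\n",
+    "def toy_encode(corpus):\n",
+    "    out = []\n",
+    "    for sentence in corpus:\n",
+    "        sentence += [''] * (toy_max_len - len(sentence))\n",
+    "        out.append([toy_word_index[w] for w in sentence])\n",
+    "    return np.array(out)\n",
+    "\n",
+    "#Shape (2, 4): both rows are padded to the same length with index 3\n",
+    "print(toy_encode([['a', 'b'], ['c', 'b', 'a']]))\n",
+    "#[[0 1 3 3]\n",
+    "# [2 1 0 3]]"
+   ]
+  },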
+  {
+   "cell_type": "code",
+   "execution_count": 72,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Store the encodings, the labels and the embedding matrix\n",
+    "encoded_text = encode(text.text)\n",
+    "labels = np.array(text['label'])\n",
+    "embeds = torch.tensor(matrix_embeddings())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 73,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[-0.3367,  0.3636, -0.3046,  ...,  0.0300, -0.2679, -0.4148],\n",
+       "        [ 0.0883,  0.0739,  0.1094,  ..., -0.1064, -0.0681, -0.2677],\n",
+       "        [-0.2731,  0.1393,  0.2855,  ...,  0.0173,  0.2927,  0.2280],\n",
+       "        ...,\n",
+       "        [ 0.2724, -0.0797, -0.0919,  ..., -0.3751,  0.1591,  0.1360],\n",
+       "        [-0.0535,  0.1561, -0.0446,  ...,  0.0316, -0.3634, -0.2427],\n",
+       "        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
+       "       dtype=torch.float64)"
+      ]
+     },
+     "execution_count": 73,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "embeds"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 74,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Split the data into train (80%) and test (20%) sets\n",
+    "xtrain, xtest, ytrain, ytest = train_test_split(encoded_text,labels,shuffle=True,test_size=0.2,random_state=15)\n",
+    "\n",
+    "\n",
+    "#Creating dataloaders for the train and test sets\n",
+    "def get_dataloader(traindata, testdata, trainlabels, testlabels ,batchsize):\n",
+    "    \n",
+    "    traindata = torch.tensor(traindata).float()\n",
+    "    testdata = torch.tensor(testdata).float()\n",
+    "    trainlabels = torch.tensor(trainlabels)\n",
+    "    testlabels = torch.tensor(testlabels)\n",
+    "    \n",
+    "    train = TensorDataset(traindata,trainlabels)\n",
+    "    train_dataload = DataLoader(train,sampler=RandomSampler(train),batch_size=batchsize,drop_last=True)\n",
+    "    test = TensorDataset(testdata,testlabels)\n",
+    "    test_dataload = DataLoader(test,sampler=RandomSampler(test),batch_size=batchsize,drop_last=True)\n",
+    "\n",
+    "    return train_dataload,test_dataload\n",
+    "\n",
+    "batchsize = 50\n",
+    "train_dataload,test_dataload = get_dataloader(xtrain, xtest, ytrain, ytest,batchsize)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 75,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The train dataset has 1967 HOF labels and 1765 NOT labels.\n",
+      "The test dataset has 502 HOF labels and 431 NOT labels.\n"
+     ]
+    }
+   ],
+   "source": [
+    "#Return the label counts\n",
+    "th,tc = tag_count(ytrain)\n",
+    "testh,testc = tag_count(ytest)\n",
+    "    \n",
+    "print(\"The train dataset has {} HOF labels and {} NOT labels.\".format(th,tc))\n",
+    "print(\"The test dataset has {} HOF labels and {} NOT labels.\".format(testh,testc))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 76,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Creating the CNN model\n",
+    "class hindi_cnnmodel(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.word_embed = embeds\n",
+    "        self.filter_sizes = [2, 3, 4]   #Kernel sizes\n",
+    "        self.num_filters = [50, 50, 50] #3 convolution branches, each with 50 filters\n",
+    "        self.num_classes=2  #Number of classes\n",
+    "        self.dropout=0.5    #Dropout probability, used to reduce overfitting\n",
+    "        self.vlen,self.es = self.word_embed.shape\n",
+    "        self.embedding = nn.Embedding.from_pretrained(self.word_embed) #Loading the trained word embedding matrix\n",
+    "\n",
+    "        #1D convolutions detect features in the sentences; each filter returns one feature map\n",
+    "        self.conv1d = nn.ModuleList([nn.Conv1d(in_channels=self.es,out_channels=self.num_filters[i],kernel_size=self.filter_sizes[i])\n",
+    "                                     for i in range(len(self.filter_sizes))])\n",
+    "        \n",
+    "        self.fc = nn.Linear(np.sum(self.num_filters), self.num_classes)\n",
+    "        self.dropout1 = nn.Dropout(p=self.dropout)\n",
+    "\n",
+    "    def forward(self,input1):\n",
+    "        x_e = self.embedding(input1).float()\n",
+    "        x_r = x_e.permute(0,2,1)\n",
+    "        #ReLU and max-pooling reduce each feature map to a single scalar;\n",
+    "        #max-pooling keeps the strongest feature of the map\n",
+    "        conv_list = [F.relu(conv(x_r)) for conv in self.conv1d]\n",
+    "        x_maxpool = [F.max_pool1d(x_conv, kernel_size=x_conv.shape[2]) for x_conv in conv_list]\n",
+    "        \n",
+    "        #Fully connected layer\n",
+    "        x_fc = torch.cat([x_pool.squeeze(dim=2) for x_pool in x_maxpool],dim=1)\n",
+    "        drop = self.dropout1(x_fc)\n",
+    "        output = self.fc(drop)\n",
+    "        return output"
+   ]
+  },
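+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A small shape walk-through of one forward pass through a single convolution branch. The sketch below assumes a sequence length of 40 (the real inputs have length max_len, which depends on the corpus); the batch size, embedding size and filter count match the model above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Illustrative sketch of the tensor shapes inside hindi_cnnmodel.forward\n",
+    "x_e = torch.randn(50, 40, 300)              #(batch, seq_len, embedding_dim)\n",
+    "x_r = x_e.permute(0,2,1)                    #(50, 300, 40): channels = embedding_dim\n",
+    "conv = nn.Conv1d(in_channels=300, out_channels=50, kernel_size=3)\n",
+    "x_conv = F.relu(conv(x_r))                  #(50, 50, 38): one feature map per filter\n",
+    "x_pool = F.max_pool1d(x_conv, kernel_size=x_conv.shape[2])  #(50, 50, 1): best feature per map\n",
+    "print(x_pool.squeeze(dim=2).shape)          #(50, 50); the 3 kernel sizes concatenate to (50, 150)"
+   ]
+  },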
+  {
+   "cell_type": "code",
+   "execution_count": 77,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "hindi_cnnmodel(\n",
+      "  (embedding): Embedding(17788, 300)\n",
+      "  (conv1d): ModuleList(\n",
+      "    (0): Conv1d(300, 50, kernel_size=(2,), stride=(1,))\n",
+      "    (1): Conv1d(300, 50, kernel_size=(3,), stride=(1,))\n",
+      "    (2): Conv1d(300, 50, kernel_size=(4,), stride=(1,))\n",
+      "  )\n",
+      "  (fc): Linear(in_features=150, out_features=2, bias=True)\n",
+      "  (dropout1): Dropout(p=0.5, inplace=False)\n",
+      ")\n"
+     ]
+    }
+   ],
+   "source": [
+    "#Initializing the model\n",
+    "model = hindi_cnnmodel()\n",
+    "\n",
+    "#Sending the model to the GPU\n",
+    "model.cuda()\n",
+    "\n",
+    "print(model)\n",
+    "\n",
+    "#Setting the parameters\n",
+    "learning_rate = 0.05\n",
+    "epochs = 10\n",
+    "\n",
+    "#We use the SGD optimizer and CrossEntropyLoss\n",
+    "optimizer = optim.SGD(model.parameters(),lr=learning_rate)\n",
+    "criterion = nn.CrossEntropyLoss()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(train_dataload,test_dataload):\n",
+    "    accuracy = 0\n",
+    "    print(\"Training started\")\n",
+    "    start = time.time()\n",
+    "    losses = []\n",
+    "    test_loss = []\n",
+    "\n",
+    "    for epoch in range(epochs):\n",
+    "        totloss = 0\n",
+    "        model.train()\n",
+    "        \n",
+    "        #Take one batch at a time from the dataloader\n",
+    "        for i,batch in enumerate(train_dataload):\n",
+    "            #Send the input and the label to the GPU\n",
+    "            x_train,y_train = tuple(x.to(torch.int64).cuda() for x in batch)\n",
+    "            #Compute the loss\n",
+    "            model.zero_grad() #Clear the gradients of the previous step\n",
+    "            output = model(x_train)\n",
+    "            loss = criterion(output,y_train)\n",
+    "            totloss += loss.item()    \n",
+    "            loss.backward()  #Compute gradients\n",
+    "            optimizer.step() #Update the weights\n",
+    "        \n",
+    "        # Calculate the average loss over the entire training data\n",
+    "        train_loss = totloss/len(train_dataload)\n",
+    "        losses.append(train_loss)    \n",
+    "        \n",
+    "        testloss,testacc = test_acc(test_dataload)\n",
+    "        test_loss.append(testloss)\n",
+    "        \n",
+    "        if testacc > accuracy:\n",
+    "            accuracy = testacc\n",
+    "        \n",
+    "        print(\"At epoch {} the training loss is {}, the test loss is {} and the best test accuracy is {}%\".format(epoch,round(train_loss,3),round(testloss,3),round(accuracy,2)))\n",
+    "    \n",
+    "    #Plot the train loss and the test loss at the end\n",
+    "    plt.plot(losses,label=\"train\")\n",
+    "    plt.plot(test_loss,label=\"test\")\n",
+    "    plt.xlabel(\"EPOCHS\")\n",
+    "    plt.ylabel(\"LOSS\")\n",
+    "    plt.legend()\n",
+    "    plt.show()\n",
+    "    print(\"Training ended\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def test_acc(test_dataload):\n",
+    "    #Evaluation mode disables the dropout layers while testing\n",
+    "    model.eval()\n",
+    "    testacc = []\n",
+    "    testloss = []\n",
+    "\n",
+    "    for batch in test_dataload:\n",
+    "        x_test,y_test = tuple(x.to(torch.int64).cuda() for x in batch)\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "            output = model(x_test) #Computing the logits\n",
+    "            #Calculate the test loss\n",
+    "            loss = criterion(output,y_test)\n",
+    "        testloss.append(loss.item())\n",
+    "        \n",
+    "        #Compute the predictions and their accuracy\n",
+    "        preds = torch.argmax(output,dim=1).flatten()\n",
+    "\n",
+    "        accu = (preds == y_test).cpu().numpy().mean() * 100\n",
+    "        testacc.append(accu)\n",
+    "    \n",
+    "    return np.mean(testloss),np.mean(testacc)"
+   ]
+  },
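+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal illustration (not part of the pipeline) of why test_acc switches the model to evaluation mode: the same dropout module zeroes roughly half of the activations in train mode but acts as the identity in eval mode."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "drop = nn.Dropout(p=0.5)\n",
+    "x = torch.ones(1, 8)\n",
+    "drop.train()\n",
+    "print(drop(x))  #roughly half the entries zeroed, the survivors scaled by 1/(1-p) = 2\n",
+    "drop.eval()\n",
+    "print(drop(x))  #identity: all ones"
+   ]
+  },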
+  {
+   "cell_type": "code",
+   "execution_count": 80,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Training started\n",
+      "At epoch 0 the training loss is 0.639, the test loss is 0.565 and the best test accuracy is 77.44%\n",
+      "At epoch 1 the training loss is 0.528, the test loss is 0.482 and the best test accuracy is 78.11%\n",
+      "At epoch 2 the training loss is 0.451, the test loss is 0.45 and the best test accuracy is 78.78%\n",
+      "At epoch 3 the training loss is 0.41, the test loss is 0.436 and the best test accuracy is 79.56%\n",
+      "At epoch 4 the training loss is 0.376, the test loss is 0.423 and the best test accuracy is 79.89%\n",
+      "At epoch 5 the training loss is 0.347, the test loss is 0.42 and the best test accuracy is 80.56%\n",
"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 6 the training loss is 0.426, the test loss is 0.317 and accuracy of 80.56%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 7 the training loss is 0.415, the test loss is 0.295 and accuracy of 81.0%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 8 the training loss is 0.414, the test loss is 0.266 and accuracy of 81.11%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 9 the training loss is 0.423, the test loss is 0.251 and accuracy of 81.11%\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAA2OElEQVR4nO3deXhV5bX48e/KTCBAgDCGkDATQKYwKCAoICAKWocqzqDYXq36U7nF1lbFetV6a9WrtVVBrBMOrYoKClZwQFHCJCRMIQxJmMIQ5szr98fewCEmBJJzsjOsz/OcJ2dP71454ll5h/2+oqoYY4wxJQV5HYAxxpjqyRKEMcaYUlmCMMYYUypLEMYYY0plCcIYY0ypQrwOwF+aNWum8fHxXodhjDE1yrJly/aoakxpx2pNgoiPjyc5OdnrMIwxpkYRka1lHQtoE5OIjBGR9SKSJiLTyjjnahFJFZEUEXnLZ3+RiKx0X3MCGacxxpifC1gNQkSCgReAUUAmsFRE5qhqqs85nYAHgMGqul9EmvsUcUxVewcqPmOMMacXyBrEACBNVdNVNR+YDUwocc5twAuquh9AVXcHMB5jjDFnIZB9EG2ADJ/tTGBgiXM6A4jIYiAYeFhVP3OPRYhIMlAIPKGqH5a8gYhMAaYAxMXF+TV4Y0zdUFBQQGZmJrm5uV6HElARERHExsYSGhp6xtd43UkdAnQChgOxwNci0lNVc4B2qpolIu2BL0Vktapu8r1YVV8CXgJISkqySaWMMWctMzOTqKgo4uPjERGvwwkIVWXv3r1kZmaSkJBwxtcFsokpC2jrsx3r7vOVCcxR1QJV3QxswEkYqGqW+zMdWAT0CWCsxpg6Kjc3l6ZNm9ba5AAgIjRt2vSsa0mBTBBLgU4ikiAiYcA1QMnRSB/i1B4QkWY4TU7pIhItIuE++wcDqRhjTADU5uRwXEV+x4AlCFUtBO4EPgfWAu+qaoqITBeR8e5pnwN7RSQVWAhMVdW9QDcgWURWufuf8B395E9H8gr582fr2Lb3aCCKN8aYGiugz0Go6lxV7ayqHVT1MXffH1V1jvteVfVeVU1U1Z6qOtvd/5273cv9OSNQMR7KLWTWd1uY/klKoG5hjDFlysnJ4W9/+9tZX3fxxReTk5Pj/4B81Pm5mFo2iuCuEZ34Yu1uvly3y+twjDF1TFkJorCw8LTXzZ07l8aNGwcoKkedTxAAkwYn0CGmPo98nEpuQZHX4Rhj6pBp06axadMmevfuTf/+/Rk6dCjjx48nMTERgMsuu4x+/frRvXt3XnrppRPXxcfHs2fPHrZs2UK3bt247bbb6N69OxdddBHHjh3zS2xeD3OtFsJCgnh4fHdumPEjL3+dzm9GdPI6JGOMBx75OIXU7Qf9WmZi64Y8dGn3Mo8/8cQTrFmzhpUrV7Jo0SLGjRvHmjVrTgxHnTlzJk2aNOHYsWP079+fK664gqZNm55SxsaNG3n77bd5+eWXufrqq/nXv/7F9ddfX+nYrQbhGtophrE9WvLCojQy91uHtTHGGwMGDDjlWYXnnnuOXr16MWjQIDIyMti4cePPrklISKB3794A9OvXjy1btvglFqtB+HjwkkQWrt/Nnz5Zy99v6Od1OMaYKna6v/SrSv369U+8X7RoEV988QXff/89kZGRDB8+vNRnGcLDw0+8Dw4O9lsTk9UgfLRpXI87L+jIZyk7+XpDttfhGGPqgKioKA4dOlTqsQMHDhAdHU1kZCTr1q1jyZIlVRqbJYgSbju/PfFNI3n44xTyC4u9DscYU8s1bdqUwYMH06NHD6ZOnXrKsTFjxlBYWEi3bt2YNm0agwYNqtLYRLV2TGGUlJSk/lowaOG63dwyaynTxnblV8M6+KVMY0z1tHbtWrp16+Z1GFWitN9VRJapalJp51sNohQXdG3OyG4teO4/G9lxwD9tecYYU9NYgijDQ5cmUlisPPbpWq9DMcYYT1iCKEPbJpH8elgHPvlpB99t2uN1OMYYU+UsQZzGr4d3IDa6Hg99lEJBkXVYG2PqFksQpxERGswfL0lk4+7DvPbdFq/DMcaYKmUJohyjElswvEsMz3yxkd0Ha/eShMYY48sSRDlEhIcu7U5+YTFPzFvndTjGmFqmotN9AzzzzDMcPRq4qYEsQZyBhGb1ue38BP69IoulW/Z5HY4xphapswlCRMaIyHoRSRORaWWcc7WIpIpIioi85bP/JhHZ6L5uCmScZ+KOCzrSulEEf/hwDYXWYW2M8RPf6b6nTp3KU089Rf/+/TnnnHN46KGHADhy5Ajjxo2jV69e9OjRg3feeYfnnnuO7du3c8EFF3DBBRcEJLaATdYnIsHAC8AoIBNYKiJzfJcOFZFOwAPAYFXdLyLN3f1NgIeAJECBZe61+wMVb3kiw0J48JJE/uvN5bz5wzZuOi/eq1CMMYEybxrsXO3fMlv2hLFPlHnYd7rv+fPn8/777/Pjjz+iqowfP56vv/6a7OxsWrduzaeffgo4czQ1atSIp59+moULF9KsWTP/xuwKZA1iAJCmqumqmg/MBiaUOOc24IXjX/yqutvdPxpYoKr73GMLgDEBjPWMjO3RkiEdm/G/89ez53Ce1+EYY2qZ+fPnM3/+fPr06UPfvn1Zt24dGzdupGfPnixYsIDf/va3fPPNNzRq1KhK4gnkdN9tgAyf7UxgYIlzOgOIyGIgGHhYVT8r49o2JW8gIlOAKQBxcXF+C7wsIsLD4xMZ88w3/Pmzdfz5yl4Bv6cxpgqd5i/9qqCqPPDAA9x+++0/O7Z8+XLmzp3Lgw8+yIgRI/jjH/8Y8Hi87qQOAToBw4FrgZdFpPGZXqyqL6lqkqomxcTEBCbCEjo2j2LykATeTc5k+TbPWryMMbWE73Tfo0ePZubMmRw+fBiArKwsdu/ezfbt24mMjOT6669n6tSpLF++/GfXBkIgaxBZQFuf7Vh3n69M4AdVLQA2i8gGnISRhZM0fK9dFLBIz9JvRnTiw5VZPPRRCh/eMZjgIPE6JGNMDeU73ffYsWOZOHEi5557LgANGjTgjTfeIC0tjalTpxIUFERoaCgvvvgiAFOmTGHMmDG0bt2ahQsX+j22gE33LSIhwAZgBM4X/lJgoqqm+JwzBrhWVW8SkWbACqA3bsc00Nc9dTnQT1XLHGPqz+m+z8RHK7O4e/ZKHru8B9cNbFdl9zXG+JdN9+3BdN+qWgjcCXwOrAXeVdUUEZkuIuPd0z4H9opIKrAQmKqqe91E8ChOUlkKTD9dcvDC+F6tGZjQhKc+X8/+I/leh2OMMX4X0D4IVZ2rqp1VtYOqPubu+6OqznHfq6req6qJqtpTVWf7XDtTVTu6r1cDGWdFiAjTJ/TgUG4hT81f73U4xhjjd153UtdoXVpGcdO58bz94zZ+yszxOhxjTAXVlpU1T6civ6MliEq6Z1QnmtYP548fpVBcXPv/kRlT20RERLB3795anSRUlb179xIREXFW1wVyFFOd0DAilAfGduW+91bx/rJMru7ftvyLjDHVRmxsLJmZmWRnZ3sdSkBFREQQGxt7Vt
dYgvCDy/u04a0ft/HkZ+sY3b0ljSJDvQ7JGHOGQkNDSUhI8DqMasmamPwgKEiYPqE7+4/m8/QC67A2xtQOliD8pHvrRlw/qB2vL9lK6vaDXodjjDGVZgnCj+4b1YXGkWE8NGdNre7wMsbUDZYg/KhRZCi/HdOFpVv288GKkrOKGGNMzWIJws+u6teWXm0b8z9z13Eot8DrcIwxpsIsQfhZUJAwfXx39h7J45kvNnodjjHGVJgliADo1bYx1/Rvy6zvtrBhV+Cm4jXGmECyBKEKy/8Jx3L8WuzU0V1pEB7CHz+yDmtjTM1kCWLPRvjk/8G/boXiIr8V26R+GPeP7sKS9H188tMOv5VrjDFVxRJETGcY+2dIWwBfPurXoicOiKNHm4Y89ulajuQV+rVsY4wJNEsQAP0nQ7+b4du/wpp/+a3Y4CDhkfE92Hkwl+e+tA5rY0zNYgniuLFPQdtB8OEdsOMnvxXbr100V/aLZea3m0nbfdhv5RpjTKBZgjguJAyu/ifUi4bZ18GRPX4r+rdjuhIRGswjH6dYh7UxpsYIaIIQkTEisl5E0kRkWinHbxaRbBFZ6b5u9TlW5LN/TiDjPCGqBVzzJhzeBe/dDEX+edAtJiqce0d15puNe/g8ZadfyjTGmEALWIIQkWDgBWAskAhcKyKJpZz6jqr2dl+v+Ow/5rN/fCnXBUabvjD+OdjyDcx/0G/F3jCoHV1bRvHoJ2s5lu+/0VLGGBMogaxBDADSVDVdVfOB2cCEAN7Pf3pdA4PugB/+Dive8EuRIcFBPDK+O1k5x/jbojS/lGmMMYEUyATRBsjw2c5095V0hYj8JCLvi4jvcmwRIpIsIktE5LLSbiAiU9xzkv2+GtSo6ZAwzHlGIjPZL0UObN+Uy3q35h9fpbNlzxG/lGmMMYHidSf1x0C8qp4DLABe8znWTlWTgInAMyLSoeTFqvqSqiapalJMTIx/IwsOgatmQVQreOd6OOSfvoPfXdyN0GCxDmtjTLUXyASRBfjWCGLdfSeo6l5VzXM3XwH6+RzLcn+mA4uAPgGMtXSRTeCatyD3gJMkCvPKv6YczRtGcM/Izixcn81/1u72Q5DGGBMYgUwQS4FOIpIgImHANcApo5FEpJXP5nhgrbs/WkTC3ffNgMFAagBjLVvLHnDZi5C5FD69z5m7qZJuHhxPx+YNeOSTFHILrMPaGFM9BSxBqGohcCfwOc4X/7uqmiIi00Xk+Kiku0QkRURWAXcBN7v7uwHJ7v6FwBOq6k2CAOh+GQy9H1a8DktfKff08oQGBzF9fHcy9h3jH1+lVz4+Y4wJAKkt7eBJSUmanOyfzuRSFRfD7Gsh7Qu48SOIH1LpIu94azlfpO7ii3uH0bZJpB+CNMaYsyMiy9z+3p/xupO65ggKgl+8BNEJ8O6NkLOt0kX+/uJuBInw6CfeVY6MMaYsliDORkQjuPZt5wnr2ddB/tFKFde6cT1+M6Ij81N3sWi9dVgbY6oXSxBnq1knuGIG7FwNc35T6U7ryUMSaN+sPo98nEpeoXVYG2OqD0sQFdH5IhjxB1jzPnz3XKWKCg8J5qHx3dm85wivfLPZTwEaY0zlWYKoqCH3QuJl8MXDTsd1JQzrHMPo7i14/ss0NtsT1saYasISREWJwGV/g+aJ8P4k2LupUsX94ZJEwkODmDRrKfuP5PspSGOMqThLEJURVt+ZHlyCYfZEyDtU4aJioyN5+cYksvYf4/bXl1l/hDHGc5YgKis63pmzac9G+PftzvMSFdQ/vgn/e3Uvftyyj6nv/URxce14RsUYUzNZgvCH9sNg9GOw/lP46slKFTW+V2umju7CnFXbeXrBBj8FaIwxZy/E6wBqjYG/ctay/uoJZ/6mbpdWuKj/Gt6BbXuP8vzCNOKaRHJ1/7blX2SMMX5mNQh/EYFL/gpt+sEHv4LdaytRlPCny3swtFMzfvfBar7d6L/1sY0x5kxZgvCn0Aj45RtO5/XsiXBsf8WLCg7ihev60iGmAb9+Yxnrd1a8A9wYYyrCEoS/NWwNV78OORnO8Nfiio9GahgRysxb+hMRFsykWUvZfTDXj4EaY8zpWYIIhLiBMO4vsOlL50G6SmjTuB4zb+rPviP5TH4tmaP5hf6J0RhjymEJIlD63QT9b3Wm4vjpvUoV1TO2Ef93bR9Sth/grrdXUmTDX40xVcASRCCNfhzizoM5d8L2lZUqamRiCx66tDtfrN1l04MbY6pEQBOEiIwRkfUikiYi00o5frOIZIvISvd1q8+xm0Rko/u6KZBxBkxIGFz9T4hs5kwPfji7UsXddF48kwYnMOu7Lby62Cb2M8YEVsAShIgEAy8AY4FE4FoRSSzl1HdUtbf7esW9tgnwEDAQGAA8JCLRgYo1oBrEwDVvwNE98N5NzloSlfD7cd24KLEF0z9JZX7KTj8FaYwxPxfIGsQAIE1V01U1H5gNTDjDa0cDC1R1n6ruBxYAYwIUZ+C17gPjn4eti+Gzn1WkzkpwkPDsNX04p00j7p69kp8yc/wTozHGlBDIBNEGyPDZznT3lXSFiPwkIu+LyPFHhs/oWhGZIiLJIpKcnV255puAO+cqOO83sPQVWPZapYqqFxbMKzf1p0n9MCbNSiZzf+VWtjPGmNJ43Un9MRCvqufg1BLO6ptTVV9S1SRVTYqJiQlIgH418hHocCF8eh9s+6FSRcVEhTPrlv7kFRYxadZSDuZWrunKGGNKCmSCyAJ8JxGKdfedoKp7VTXP3XwF6Hem19ZIQcFw5UxoFAvv3gAHt1equE4tovjH9f1Izz7Cr99YRn5hxWeSNcaYkgKZIJYCnUQkQUTCgGuAOb4niEgrn83xwPEJjD4HLhKRaLdz+iJ3X81XLxqufRvyj8A710NB5Z6OPq9jM5644hwWp+3l9x+sRiu5RrYxxhwXsAShqoXAnThf7GuBd1U1RUSmi8h497S7RCRFRFYBdwE3u9fuAx7FSTJLgenuvtqheTe4/B+QtQw+vRcq+aV+Zb9Y7hrRifeWZfLCwjQ/BWmMqeuktvzFmZSUpMnJyV6HcXYWPu5MDz7mSRj0q0oVparc++4qPliRxbPX9GZC79LGAxhjzKlEZJmqJpV2zOtO6rpt2G+hyzj4/HeQ/lWlihIRnriiJwMSmjD1vZ/4cXPtqXAZY7xhCcJLQUFw+d+haUd472bYv6VSxYWHBPPSDf2IbVKPKa8nk5592C9hGmPqJksQXoto6HRaa5EzHcf+rZUqrnFkGLNuHkCwCLfMWsrew3nlX2SMMaWwBFEdNO3gDH/dsxH+ry98dAfsS69wcXFNI3n5piR2Hsjltn8mk1tQ8TUpjDF1lyWI6qLjSLhrBSRNhtXvw/8lwb9vd5JGBfSNi+aZX/ZmRUYO9727imKbItwYc5YsQVQnjdrAxX+Gu1fBoF9D6kfwfH94f3KF1rge27MVD4ztyqerd/Dnz9cHIGBjTG1mCaI6imoJox+De1bD4Lth/Tz427nw7o2wc/VZFXXb0PZcNzCOv3+1ibd+2BaggI0xtZEliOqsQQyMesRJFEPvg
00L4e9D4O2JsH3FGRUhIjwyvjvDu8Twh4/WsGj97gAHbYypLSxB1AT1m8KIP8A9P8HwB2Drt/DScHjzKshYWu7lIcFBPD+xL11aRHHnWytYu+Ng4GM2xtR4liBqknrRMHwa3LMGLvwDZCbDjJHw+uWw9fvTXtogPISZN/enQXgIk2YtZeeBys0BZYyp/SxB1EQRDeH8+52mp1HTYcdP8OoYmHUJbP6mzLmdWjaKYObN/Tl4rIBJs5ZyOK+wigM3xtQkliBqsvAGTif2Path9P/Ang3w2iXw6ljY9GWpiSKxdUNeuK4v63cd4jdvLaewyKYIN8aUzhJEbRAWCefe4QyPHfsU5Gxzmp1eGQkb5v8sUQzv0pzpE7qzcH02j3ycalOEG2NKddoEISKXikg7n+0/isgqEZkjIgmBD8+cldB6MHCK88DdJX+Fw7vhraucDu11n56SKK4b2I7bh7Xn9SVbmfHtZu9iNsZUW+XVIB4DsgFE5BLgemASzsI/fw9saKbCQsIhaRLctRzGPw+5OTB7Ivx9KKR8CMVOs9JvR3dlXM9WPDZ3LfNW7/A0ZGNM9VNeglBVPeq+/wUwQ1WXqeorQA1YBLqOCw6FvjfAncucBYoKj8F7N8GL58Hq9wmimL9c3Ys+bRtzzzsrWbFtv9cRG2OqkfIShIhIAxEJAkYA//E5FhG4sIxfBYdAr2vgjh/hihmAwr8mwwsDiUh9j5ev702LhhHc+loy2/YeLbc4Y0zdUF6CeAZYCSQDa1U1GUBE+gDltkmIyBgRWS8iaSIy7TTnXSEiKiJJ7na8iBwTkZXuy5qz/CEoGHpeCb/+Hq56zWmK+uB2mr46mPcHboKiAm6Z9SMHjhZ4Hakxphood8lREWkDNAdWqWqxu68lEKaqZU7uIyLBwAZgFJCJs7b0taqaWuK8KOBTIAy4U1WTRSQe+ERVe5zpL1Ijlxz1WnExbJgHXz0JO1aRV78Nfzo4lk2tx/PMdQNp3tAqicbUdhVectQdwXRYVVeoarGIXCAizwITgZ3l3HcAkKaq6aqaD8wGJpRy3qPAk4A92lvVgoKg6ziY8hVMfI/wxq14NPgVnt95HSl/uZiVr0+jaP3ncDjb60iNMR4IKef4u8DlwAER6Q28BzwO9AL+Btx6mmvbABk+25nAQN8TRKQv0FZVPxWRqSWuTxCRFcBB4EFV/abkDURkCjAFIC4urpxfxZRJBDpfBJ1GwaYvCV36Fp02LaV12t8J2vSic07DWGjdG1r3OfmKbOJp2MaYwCovQdRT1e3u++uBmar6F7fTemVlbuyW8TRwcymHdwBxqrpXRPoBH4pId1U9ZZY5VX0JeAmcJqbKxGNwEkXHEUR1HEEDVeYuS+PDefNol7eeCcG7SNyVSvC6T06e37jdqQmjVS+o19iz8I0x/lVeghCf9xcCDwC4zU3llZ0FtPXZjnX3HRcF9AAWuWW1BOaIyHi3MzzPvdcyEdkEdMbpLDdVQEQYl9SJId3j+cv89YxfspWYBuFMvySW0U12IttXOlOOb18BqR+evLBJhxJJ4xwIj/Lq1zDGVMJpO6nd/oZWOH/Rjwc6q2qBiLQCPi6rY8O9NgSnk3oETmJYCkxU1ZQyzl8E3O92UscA+1S1SETaA98APVV1X1n3s07qwFqVkcPvPlhNyvaDnN85hkcndKdd0/rOwaP7TiaL7Stg+0o4mOleKdCs86lJo2VPZ3oQY4znTtdJXV6CEOCXOEniXVXNcvf3AZqr6ufl3PhinKGywTjNU4+JyHQgWVXnlDh3EScTxBXAdKAAKAYeUtWPT3cvSxCBV1hUzOtLtvKX+RsoKCrmzgs6MmVYe8JDgn9+8uHdTqI4kTSWw+FdzjEJgphubsLoDa37QovuEGqjpoypahVOED4FJADd3c1UVU33Y3x+YQmi6uw8kMujn6Ty6eodtI+pz58u68F5HZqVf+HBHSVqGsvh6F7nWFAINE88tabRPBFCwgL7yxhTx1WmBtEQeAXoB6xyd/cGlgGTS3Yae8kSRNVbuH43f/xoDRn7jvGLPm343bhuNGsQfuYFqMKBzBJJY4UzdxRAcBi06HEyYbTpC826OE+GG2P8ojIJYhawBZju85CcAH8AOqrqjX6PtoIsQXgjt6CI579M4x9fbyIyLITfjunKNf3bEhRU7iCG0qnC/i1O7eJ4E9WOVZDn/i0SUs/p+PataTTt6Dwlbow5a5VJEBtVtdPZHvOCJQhvpe0+xIMfrmFJ+j76xjXmT5f1JLF1Q/8UXlwM+9JPbZrasQoK3Hmjwho4Q2x9k0Z0gvMgoDHmtAKVINJUtaOfYqw0SxDeU1U+WJHFY5+uJedYAZMGx3PPyM7UDw9Ak1BxkbOCnm/T1M7VUOg+kB/eCFqXSBqN2znPehhjTqhMgngN2AQ8qj4nisgfcIa83uDvYCvKEkT1kXM0nyc/W8/bP26jVaMIHh7fnYsSW3AGz85UTlEBZK8rkTTWQLE7+WC96FMTRus+0LCNJQ1Tp1W2k3oG0JeTT073BlbgdFIf8GuklWAJovpZtnUfv/9gDet2HmJkt+Y8PL47sdFV/PxDYR7sTj2ZMLJWONta5Byv3/znSSOqRdXGWB2oOv08R/Y4r6N74Ei2+9p78n3uAWjc1hmm3LyrM9KsSQcbbVaD+WOYawcg0d1MVdVNInKPqj7jvzArxxJE9VRQVMyrizfz1wUbAbh7ZCcmD0kgNNjD/oGCY07NwremsWc9OOMwIKp1iaapts5yriH1nGc1QurVjJFU+UdOfuEfyfb50vfZdyTbGWp8JBuK8ksvJywK6jeD+jHOU/E5W50+oeOfV1CIM1Agxk0Yzbs6CaRJ+5rxOdVUxcVwaIczqEOLIWFohYqpdIIoo9BtqlptZsizBFG9ZeUc45E5KcxP3UXnFg147PKe9I+vRpP95R12+jB8k8bejWWfHxQCoZEQEnEyaYRG+Oyrd/Kn7/tT9tU7w2vrOR3uBbnul3x5X/puDaCgjMWfQiOdL/xI90u/fgzUb+r8PLGv2clzSnuAsSDX6QPKXge717o/U2H/VsD9TgkOc56ij+l6srYR0xWi423U2ZnKP+IkgJKvfZudRH08qbfuA1MWVegWgUoQGaratvwzq4YliJphQeouHp6TQlbOMa5OimXa2G40qV9NmydyD8COn5wnwAtznZpHwbGT7wtznS/hglxnOdcTP4+Vse8YJ748z1ZQ6Mm+lJKCw09+oZ/4kvfZLrkvrH6FP5Jy5R91amO71/okjnVwwGfpmJAIJ3H41jaad4NGbeveyLPiYji88+df/sffH9l96vnhDZ0EW/LVtIPzswKsBmGqlaP5hTz7n43M+GYzUREhPHBxN67qFxv4TmyvqTp/8VUkuRTmOsN5T/x1HwORTU82+1T3zy7vEGSvP7W2sXsdHNp+8pzQ+hDTxUkWzbud7Oeo6QMJ8o86f+2X/PLfv8XZf3zkHTjT0DSMheh2zhd+kwSfRJDgDLTw82dRmU7qQ5T+J4/gTAVebRoYLUHUPOt2HuTBD9aQvHU/A+Kb8KfLe9C5
hc38WqccyynRTOXWPHz/cg5veLKZ6nhto2kHp+YkQU5zlYjzXoJ99gX5vAKYYFSdWmbJL//97vbxOciOC4uCJvGl1AQSnFpUFXf4B6QGUd1YgqiZiouV95Zl8Pi8dRzOLeS289tz14WdqBdmbdR12tF9btJwE8budc7743N3nTUpkTSOJ5KgUvYFn0wqpSYcNyEFBbu1g21OTc/3Xo1i3S/9ds4X//EEEB3vLLRVjWpEliBMtbf3cB6Pz1vH+8syiY2ux/QJ3bmwax0cbmpO73C22xm+BYoLndE7Wuw8OHn8vbrvi4t/vu9n55bcd/ynlrLv+Hl6cl9w2Kk1gCYJTnIIOYs5yTxmCcLUGEvS9/Lgh2tI232Ykd1aMHV0F7q0tGYnYwLldAmijg0ZMNXdoPZNmXvXUP57TBeWpO9lzLNfc/fsFWzec8Tr0Iypc6wGYaqt/Ufy+cfX6cz6bjMFRcqVfWO5a2Qn2jSu53VoxtQantUgRGSMiKwXkTQRmXaa864QERWRJJ99D7jXrReR0YGM01RP0fXDmDa2K1//9wXcMKgdH6zI4oKnFvHQR2vYfTC3/AKMMZUSsBqEiATjrEk9CsjEWZP6WlVNLXFeFPApEAbc6S45mgi8DQwAWgNf4EwOWFTW/awGUfttzznG/32ZxnvJGQQHCTedF8+vhnWovg/aGVMDeFWDGACkqWq6quYDs4EJpZz3KPAk4Psn4QRgtqrmqepmIM0tz9RhrRvX4/Ff9OQ/9w1jXM9WvPxNOkOf/JKn56/nwLEynjI2xlRYIBNEGyDDZzvT3XeCiPQF2qrqp2d7rXv9FBFJFpHk7Oxs/0Rtqr12Tevz9C97M/+e8xnepTnPfZnG+X9eyAsL0ziSV+h1eMbUGp6NYhKRIOBp4L6KlqGqL6lqkqomxcTE+C84UyN0ahHFC9f15ZPfDCGpXTRPfb6eYU8tZMa3m8ktKLM10hhzhgKZILIA38n8Yt19x0UBPYBFIrIFGATMcTuqy7vWmBN6tGnEjJv7869fn0eXllE8+kkqw59axBtLtpJfWOx1eMbUWIHspA7B6aQegfPlvhSYqKopZZy/CLjf7aTuDrzFyU7q/wCdrJPanInvNu3hL/M3sGzrfto2qcfdIzpzWe/WhHi5BoUx1ZQnndSqWgjcCXwOrAXeVdUUEZkuIuPLuTYFeBdIBT4D7jhdcjDG13kdmvH+r87l1Vv606heKPe/t4qLnvmaj1dtp7i4djz3Y0xVsAflTK2mqnyesounF6xnw67DdG0ZxX0XdWFkt+a1f3pxY86ATbVh6iwRYUyPlsy7+3yevaY3uQVF3PbPZC7723d8szGb2vIHkjGBYAnC1AnBQcKE3m1YcO8wnryiJ3sO5XHDjB/55UtLWLpln9fhGVMtWROTqZPyCot4Z2kG//dlGtmH8ji/cwz3X9SZc2Ibex2aMVXKpvs2pgzH8ot4fckWXly0if1HC7gosQX3XtSZri0beh2aMVXCEoQx5TicV8jMbzfz8tfpHM4v5NJzWnPPyE60j2ngdWjGBJQlCGPOUM7RfF76Op1XF28hr7CIK/rG8psLOxHXNNLr0IwJCEsQxpylPYfzeHHRJl5fspXComJGJbZg8pD29I+PtuGxplaxBGFMBe06mMs/v9/Cmz9sI+doAT3bNGLSkHjG9WxNWIgNAjQ1nyUIYyrpWH4RH6zIYubizaTtPkzzqHBuPLcdEwe2s/UoTI1mCcIYPykuVr7emM3MxVv4ekM24SFB/KJvGyYNTqBTiyivwzPmrJ0uQYRUdTDG1GRBQcLwLs0Z3qU5G3cdYubiLfx7eSZv/5jB0E7NmDwkgWGdY6yfwtQKVoMwppL2HcnnrR+28s/vt7L7UB4dmzfglsHx/KJPLPXCgr0Oz5jTsiYmY6pAfmExn67ezoxvN7Mm6yCNI0OZOCCOG8+Np2WjCK/DM6ZUliCMqUKqytIt+5nxbTrzU3cRLMK4c1oxeUiCTeVhqh3rgzCmCokIAxKaMCChCdv2HuW177fwztIMPlq5naR20UweksBF3VsSHGT9FKZ6sxqEMVXgUG4B7yZnMuu7zWTsO0ZsdD1uPi+eq/u3pWFEqNfhmTrMsyYmERkDPAsEA6+o6hMljv8KuAMoAg4DU1Q1VUTicVahW++eukRVf3W6e1mCMDVBUbGyIHUXMxdv5sfN+6gfFsxVSW25ZXA87ZrW9zo8Uwd5kiBEJBhnTepRQCbOmtTXqmqqzzkNVfWg+3488F+qOsZNEJ+oao8zvZ8lCFPTrM48wMzFm/nkp+0UFisju7Vg8pAEBiY0sWGypsp41QcxAEhT1XQ3iNnABJx1pgE4nhxc9YHa0d5lzBnoGduIv/6yN9PGduX177fy5g9bWZC6i+6tGzJpcAKX9GpFeIgNkzXeCeRkMm2ADJ/tTHffKUTkDhHZBPwZuMvnUIKIrBCRr0RkaADjNMZTLRpGcP/oLnz/wAge/0VP8guLue+9VQx5ciHP/Wcjew/neR2iqaMC2cR0JTBGVW91t28ABqrqnWWcPxEYrao3iUg40EBV94pIP+BDoHuJGgciMgWYAhAXF9dv69atAfldjKlKqso3G/cwc/FmFq3PJiwkiMt7t2Hy0AQ623Qexs+86oM4F3hYVUe72w8AqOrjZZwfBOxX1UalHFsE3K+qZXYyWB+EqY3Sdp+cziO3oJhhnWO4bWh7Bndsav0Uxi9OlyAC2cS0FOgkIgkiEgZcA8wpEVgnn81xwEZ3f4zbyY2ItAc6AekBjNWYaqlj8yj+5/KefD9tBPeN6kzK9oNcP+MHLn7uW/61LJP8wmKvQzS1WKCHuV4MPIMzzHWmqj4mItOBZFWdIyLPAiOBAmA/cKeqpojIFcB0d38x8JCqfny6e1kNwtQFuQVFzFm5nZe/SWfj7sO0aBjOzeclMHFAHI0i7XkKc/Zsqg1jahlV5asN2bz8TTqL0/YSGRbM1UltmTwkgbZNbHlUc+YsQRhTi6VsP8CMbzYzZ9V2ilUZ26MVtw5NoE9ctNehmRrAEoQxdcCOA8eY9d0W3vphG4dyC0lqF82tQ9szKrGFzftkymQJwpg65HBeIe8uzWDm4s1k7j9GfNNIJg9J4Mp+bW19CvMzliCMqYMKi4r5PGUXL32TzqqMHBpHhnL9wHbceF47mkfZ+hTGYQnCmDpMVUneup+Xv05nwdpdhAYFcVmf1tw6tL09eGdsPQhj6jIRoX98E/rHN2HzniPM/HYz7y3L4N3kTHvwzpyW1SCMqYP2HcnnzSVbee37rew5nEe3Vg25dUgCl/ZqTVhIIJ+fNdWNNTEZY0plD94ZSxDGmNNSVRZtyOYVe/CuzrEEYYw5Y/bgXd1iCcIYc9bKevBuRLfmhAZbP0VtYQnCGFNhxx+8m/HtZrJyjtGoXigXJbZgbM+WDO7YzFa9q+EsQRhjKq2wqJiF67OZt3oHC1J3cSivkKjwEEYmtmBMj5YM6xxDRKgli5rGnoMwxlRaSHAQoxJbMCqxBXmFRXyXtpe5q3ewYO0uPliRRWRYMBd2bc7YHq24oGs
MkWH29VLT2X9BY8xZCw8J5oKuzbmga3MKiopZkr6Xuat3Mj9lJ5/8tIOI0CCGd27O2J4tubBrc6IibMhsTWRNTMYYvykqVn7cvI/P1uxg3pqd7D6UR1hwEEM7NWNsz1aM6tbCnq+oZqwPwhhT5YqLleXb9jNvzU7mrd7B9gO5hAQJgzs2Y2yPllzUvSVN6od5HWad51mCEJExwLM4S46+oqpPlDj+K+AOoAg4DExR1VT32APAZPfYXar6+enuZQnCmOpLVVmVeYB5a3Ywb/VOtu07SnCQMKh9E8b0aMXo7i1shlmPeJIgRCQY2ACMAjKBpcC1xxOAe05DVT3ovh8P/JeqjhGRROBtYADQGvgC6KyqRWXdzxKEMTWDqpKy/SCfrdnJ3DU7SM8+ggj0b9eEsT1bMqZHS1o1qud1mHWGV6OYBgBpqpruBjEbmACcSBDHk4OrPnA8W00AZqtqHrBZRNLc8r4PYLzGmCogIvRo04gebRpx30Wd2bj7MHNXOzWLRz5O5ZGPU+kT15iLe7RiTI+WNtWHhwKZINoAGT7bmcDAkieJyB3AvUAYcKHPtUtKXNumlGunAFMA4uLi/BK0MabqiAidW0TRuUUU94zszKbsw07NYvUOHpu7lsfmrqVnm0aM7dmSsT1akdCsvtch1ymBbGK6Ehijqre62zcAA1X1zjLOnwiMVtWbROR5YImqvuEemwHMU9X3y7qfNTEZU7ts23uUeWt2MHfNTlZl5ADQtWUUF/dsxcU9W9GxeQNvA6wlvGpiygLa+mzHuvvKMht4sYLXGmNqmbimkdw+rAO3D+tAVs4xPnNHQ/31iw08vWAD53VoyqTBCVzYtTlBQbbYUSAEsgYRgtNJPQLny30pMFFVU3zO6aSqG933lwIPqWqSiHQH3uJkJ/V/gE7WSW2M2XUwlw9WZPHad1vYcSCXhGb1uWVwPFf0jaV+uD37e7a8HOZ6MfAMzjDXmar6mIhMB5JVdY6IPAuMBAqA/cCdxxOIiPwemAQUAveo6rzT3csShDF1S0FRMfPW7GTGt5tZlZFDw4gQrh0Yx03nxtO6sY2COlP2oJwxplZbtnU/M7/dzLw1OxARxvZoyeQhtobFmbDJ+owxtVq/dtH0axdN5v6jvPbdFmb/mMEnP+2gb1xjJg9pz+juLQixNSzOmtUgjDG1zuG8Qt5PzuDV77awde9R2jSux03nteOX/eNoVM/mgvJlTUzGmDqpqFj5z9pdzPh2Mz9s3kdkWDBX9YvllsEJxNszFYAlCGOMYU3WAWYu3szHq7ZTWKyM6NqCyUMSGNS+CSJ1d5isJQhjjHHtPpjLG0u28sYP29h3JJ/EVg2ZNCSBS3u1qpPLp1qCMMaYEnILivhwRRYzF29mw67DxESFc8Ogdlw3MI6mDcK9Dq/KWIIwxpgyqCrfbNzDzMWbWbQ+m7CQIC7v3YZJQxLo0jLK6/ACzoa5GmNMGUSE8zvHcH7nGNJ2H2Lm4i38e3km7yRnMLRTMyYNSWBYp5g6OZ2H1SCMMaaE/UfyeevHbbz23RZ2H8qjQ0x9bhmcwBV9Y6kXVrv6KayJyRhjKiC/sJi5q3cw49vNrM46QOPIUCYOiOPGc+Np2ah2rIBnCcIYYypBVVm6ZT8zvk1nfuougkUYd04rrh/Ujl6xjQkLqblPaVsfhDHGVIKIMCChCQMSmrBt71FmfbeFd5Mz+GjldsJCgujRuiF94qLpE9eYPnHRtG4UUSuerbAahDHGVMCh3AK+3biHFRk5rNi2n58yD5BXWAxA86hw+vokjJ5tGlXbvgurQRhjjJ9FRYQytmcrxvZsBTjTj6/bcYjl2/azYtt+VmTk8FnKTgCCg4RuraJOJo220bRrGlntaxlWgzDGmADZeziPlRk5btLIYVVGDkfynXXPoiNDnWapto3p2y6ac2IbERVR9RMJWg3CGGM80LRBOCO6tWBEtxaAM3ngxt2HWLHNaZZavi2HL9ftBkAEOjePcpulGtM3LpoOMQ08ff4i0CvKjQGexVlR7hVVfaLE8XuBW3FWjcsGJqnqVvdYEbDaPXWbqo4/3b2sBmGMqYkOHCtgVUaOkzQynJrGgWMFAESFh9A7rjF92jp9Gb3bNia6fphf7+/JMFcRCcZZk3oUkImzJvW1qprqc84FwA+qelREfg0MV9VfuscOq2qDM72fJQhjTG2gqqTvOXKilrFiWw7rdh6k2P2qbt+svpM03Oapri2jKrUYkldNTAOANFVNd4OYDUwATiQIVV3oc/4S4PoAxmOMMdWeiNAhpgEdYhpwZb9YAI7kFbI668CJvoyvN2Tz7+VZANQLDWZEt+Y8P7Gv32MJZIJoA2T4bGcCA09z/mRgns92hIgk4zQ/PaGqH/o9QmOMqQHqh4cwqH1TBrVvCji1jMz9x1iRkcPyrfuJDNAQ2mrRSS0i1wNJwDCf3e1UNUtE2gNfishqVd1U4ropwBSAuLi4KovXGGO8JCK0bRJJ2yaRjO/VOmD3CeTz4VlAW5/tWHffKURkJPB7YLyq5h3fr6pZ7s90YBHQp+S1qvqSqiapalJMTIx/ozfGmDoukAliKdBJRBJEJAy4Bpjje4KI9AH+gZMcdvvsjxaRcPd9M2AwPn0XxhhjAi9gTUyqWigidwKf4wxznamqKSIyHUhW1TnAU0AD4D33icLjw1m7Af8QkWKcJPaE7+gnY4wxgWdPUhtjTB12umGuNXeOWmOMMQFlCcIYY0ypLEEYY4wplSUIY4wxpao1ndQikg1srUQRzYA9fgqnprPP4lT2eZzKPo+TasNn0U5VS32QrNYkiMoSkeSyevLrGvssTmWfx6ns8ziptn8W1sRkjDGmVJYgjDHGlMoSxEkveR1ANWKfxans8ziVfR4n1erPwvogjDHGlMpqEMYYY0plCcIYY0yp6nyCEJExIrJeRNJEZJrX8XhJRNqKyEIRSRWRFBG52+uYvCYiwSKyQkQ+8ToWr4lIYxF5X0TWichaETnX65i8JCL/z/3/ZI2IvC0iEV7H5G91OkGISDDwAjAWSASuFZFEb6PyVCFwn6omAoOAO+r45wFwN7DW6yCqiWeBz1S1K9CLOvy5iEgb4C4gSVV74CxpcI23UflfnU4QwAAgTVXTVTUfmA1M8Dgmz6jqDlVd7r4/hPMF0MbbqLwjIrHAOOAVr2Pxmog0As4HZgCoar6q5ngalPdCgHoiEgJEAts9jsfv6nqCaANk+GxnUoe/EH2JSDzOMq8/eByKl54B/hso9jiO6iAByAZedZvcXhGR+l4H5RV3SeT/BbYBO4ADqjrf26j8r64nCFMKEWkA/Au4R1UPeh2PF0TkEmC3qi7zOpZqIgToC7yoqn2AI0Cd7bMTkWic1oYEoDVQX0Su9zYq/6vrCSILaOuzHevuq7NEJBQnObypqv/2Oh4PDQbGi8gWnKbHC0XkDW9D8lQmkKmqx2uU7+MkjLpqJLBZVbNVtQD4N3CexzH5XV1PEEuBTiKSICJhOJ1MczyOyTPiLAw+A1irqk97HY+XVPUBVY1V1Xicfxdfqmqt+wvxTKnqTi
BDRLq4u0YAdXmd+G3AIBGJdP+/GUEt7LQP8ToAL6lqoYjcCXyOMwphpqqmeByWlwYDNwCrRWSlu+93qjrXu5BMNfIb4E33j6l04BaP4/GMqv4gIu8Dy3FG/62gFk67YVNtGGOMKVVdb2IyxhhTBksQxhhjSmUJwhhjTKksQRhjjCmVJQhjjDGlsgRhTBlEpEhEVvq8prn7F7kzAK8SkcXHnw0QkTARecadGXijiHzkzud0vLyWIjJbRDaJyDIRmSsinUUkXkTWlLj3wyJyv/t+kIj84MawVkQersKPwdRhdfo5CGPKcUxVe5dx7DpVTRaRKcBTwHjgf4AooIuqFonILcC/RWSge80HwGuqeg2AiPQCWnDqfGCleQ24WlVXuTMQdynnfGP8whKEMZXzNXCPiETiPDiWoKpFAKr6qohMAi4EFChQ1b8fv1BVV8GJiRFPpznOhHC4ZdflJ5hNFbIEYUzZ6vk8UQ7wuKq+U+KcS4HVQEdgWymTGyYD3d33p5v4r0OJe7XEmS0U4K/AehFZBHyGUwvJPdNfwpiKsgRhTNlO18T0pogcA7bgTEERXcl7bfK9l28/g6pOF5E3gYuAicC1wPBK3s+YclmCMKZirlPV5OMbIrIPiBORKHexpeP6AceXK72yojdT1U3AiyLyMpAtIk1VdW9FyzPmTNgoJmP8QFWP4HQmP+12JCMiN+KsNPal+wp3O7Vxj58jIkPLK1tExrkzhgJ0AoqAHP/+Bsb8nCUIY8pWr8Qw1yfKOf8BIBfYICIbgauAy9UFXA6MdIe5pgCPAzvPII4bcPogVgKv49Reiir6Sxlzpmw2V2OMMaWyGoQxxphSWYIwxhhTKksQxhhjSmUJwhhjTKksQRhjjCmVJQhjjDGlsgRhjDGmVP8f61+P9kzCqMsAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training ended\n" + ] + } + ], + "source": [ + "train(train_dataload,test_dataload)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/LSTM-classifier-Task-2.ipynb b/LSTM-classifier-Task-2.ipynb new file mode 100644 index 0000000..d211c56 --- /dev/null +++ b/LSTM-classifier-Task-2.ipynb @@ -0,0 +1,367 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "import torch\n", + "import pickle\n", + "import datetime\n", + "import numpy as np\n", + "import warnings\n", + "\n", + "import pandas as pd\n", + "from pandas.core.common import SettingWithCopyWarning\n", + "\n", + "warnings.simplefilter(action=\"ignore\", category=SettingWithCopyWarning)\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "\n", + "from matplotlib import pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "if device == 'cuda':\n", + " print(\"=================================\")\n", + " print(\"GPU found\")\n", + " print(\"Using GPU at cuda:\",torch.cuda.current_device())\n", + " print(\"=================================\")\n", + " print(\" \")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model1 = torch.load(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/epoch_300.pt\")\n", + "\n", + "w1 = model1[\"w1.weight\"].T\n", + "w2 = model1[\"w2.weight\"]\n", + "\n", + "cleandata = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/hindi_corpus_cleaned.pkl\")\n", + "word_index = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/word_index.pkl\")\n", + "index_word = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/index_word.pkl\")\n", + "V = pd.read_pickle(\"~/NNTI-WS2021-NLP-Project/Project_files/Hindi/vocab.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\\t')\n", + "text = data[['text','task_1']]\n", + "text['text'] = cleandata['text'].apply(lambda x: x.split())\n", + "text['label'] = text['task_1'].apply(lambda x: 1 if x == 'HOF' else 0)\n", + "max_len = text.text.str.len().max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def tag_count(input):\n", + " hcount,ncount = 0,0\n", + " for tag in input:\n", + " if tag == 1:\n", + " hcount+=1\n", + " else:\n", + " ncount+=1\n", + " return hcount,ncount\n", + "\n", + "word_index[''] = len(V)\n", + "index_word[len(V)] = ''\n", + "\n", + "\n", + "def get_word_embedding(input):\n", + " index = word_index[input]\n", + " return w1[index]\n", + 
"\n", + "def encode(corpus):\n", + " sent_idx = []\n", + " i = 0\n", + " for sentence in corpus:\n", + " sentence += [''] * (max_len - len(sentence))\n", + " idx = [word_index[word] for word in sentence]\n", + " sent_idx.append(idx)\n", + " i+= 1\n", + " return np.array(sent_idx)\n", + "\n", + "\n", + "from itertools import islice\n", + "def take(n, iterable):\n", + " \"Return first n items of the iterable as a list\"\n", + " return list(islice(iterable, n))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_ , emb_size = w1.shape\n", + "\n", + "def matrix_embeddings():\n", + " embedding_matrix = np.random.uniform(-0.25, 0.25, (len(word_index), emb_size))\n", + " embedding_matrix[word_index['']] = np.zeros((emb_size,))\n", + "\n", + " for word,i in take(len(V),word_index.items()):\n", + " temp = get_word_embedding(word)\n", + " if temp is not None:\n", + " embedding_matrix[i] = temp.cpu()\n", + "\n", + " return embedding_matrix\n", + "\n", + "encoded_text = encode(text.text)\n", + "labels = np.array(text['label'])\n", + "embeds = torch.tensor(matrix_embeddings())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame({'encoded':list(encoded_text), 'label':labels})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xtrain,xtest,ytrain,ytest = train_test_split(encoded_text,labels,shuffle=True,test_size=0.1,random_state=45)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 10\n", + "\n", + "from torch.utils.data import (TensorDataset, DataLoader, RandomSampler,SequentialSampler)\n", + "\n", + "def get_dataloader(traindata,trainlabels,testdata,testlabels,batchsize):\n", + " \n", + " traindata = torch.tensor(traindata).float()\n", + " trainlabels = torch.tensor(trainlabels)\n", + " testdata = torch.tensor(testdata).float()\n", + " testlabels = torch.tensor(testlabels)\n", + "\n", + " test = TensorDataset(testdata,testlabels)\n", + " test_dataload = DataLoader(test,sampler=RandomSampler(test),batch_size=batchsize,drop_last=True)\n", + " train = TensorDataset(traindata,trainlabels)\n", + " train_dataload = DataLoader(train,sampler=RandomSampler(train),batch_size=batchsize,drop_last=True)\n", + "\n", + " return train_dataload,test_dataload\n", + "\n", + "train_dataload,test_dataload = get_dataloader(xtrain, ytrain,xtest,ytest,batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class lstmmodel(nn.Module):\n", + " def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):\n", + " super().__init__()\n", + " \n", + " self.output_size = output_size\n", + " self.n_layers = n_layers\n", + " self.hidden_dim = hidden_dim\n", + " \n", + " #Pretrained Embeddings\n", + " self.embedding = nn.Embedding.from_pretrained(embeds.type(torch.float32),freeze=True)\n", + " #LSM Layer\n", + " self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=drop_prob, batch_first=True)\n", + " \n", + " #Dropout Layer\n", + " self.dropout = nn.Dropout(0.3)\n", + " \n", + " #Fully connected Layer\n", + " self.fc = nn.Linear(hidden_dim, output_size)\n", + " \n", + " #sigmoid Layer\n", + " self.sig = nn.Sigmoid()\n", + "\n", + " def forward(self, x, hidden):\n", + " batch_size = x.size(0)\n", + " x = x.long()\n", + " embed = 
self.embedding(x)\n", + " lstm_out, hidden = self.lstm(embed.type(torch.float32), hidden)\n", + " lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)\n", + " out = self.dropout(lstm_out)\n", + " out = self.fc(out)\n", + " sig_out = self.sig(out)\n", + "\n", + " sig_out = sig_out.view(batch_size, -1)\n", + " sig_out = sig_out[:, -1]\n", + "\n", + " return sig_out, hidden\n", + " \n", + " def init_hidden(self, batch_size):\n", + " weight = next(self.parameters()).data\n", + "\n", + " hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda(), \n", + " weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda())\n", + "\n", + " \n", + " return hidden" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "vocab_size = embeds.shape[0]\n", + "output_size = 2\n", + "embedding_dim = embeds.shape[1]\n", + "hidden_dim = 256\n", + "n_layers = 1\n", + "\n", + "model = lstmmodel(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)\n", + "\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# loss and optimization functions\n", + "lr=0.05\n", + "criterion = nn.BCELoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = model.to(device)\n", + "criterion = criterion.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epochs = 10\n", + "counter = 0 \n", + "print_every = 50\n", + "clip = 5\n", + "\n", + "model.cuda()\n", + "\n", + "model.train()\n", + "for e in range(epochs):\n", + " tot = []\n", + " h = model.init_hidden(batch_size)\n", + " for inputs, labels in train_dataload:\n", + "\n", + " counter += 1\n", + " \n", + " inputs, labels = inputs.type(torch.float64).cuda(), labels.cuda()\n", + " h = tuple([each.data for each in h])\n", + " \n", + " model.zero_grad()\n", + " \n", + " output, h = model(inputs, h)\n", + " loss = criterion(output.squeeze(), labels.float())\n", + " loss.backward()\n", + " nn.utils.clip_grad_norm(model.parameters(), clip)\n", + " optimizer.step()\n", + " tot.append(loss.item())\n", + " if counter % print_every == 0:\n", + " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", + " \"Step: {}...\".format(counter),\n", + " \"Loss: {:.6f}...\".format(loss.item()))\n", + "plt.plot(tot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_losses = []\n", + "num_correct = 0\n", + "\n", + "\n", + "h = model.init_hidden(10)\n", + "model.eval()\n", + "\n", + "for inputs, labels in test_dataload:\n", + " h = tuple([each.data for each in h])\n", + " \n", + " inputs, labels = inputs.cuda(), labels.cuda()\n", + " \n", + " output, h = model(inputs, h)\n", + " test_loss = criterion(output.squeeze(), labels.float())\n", + " test_losses.append(test_loss.item())\n", + " pred = torch.round(output.squeeze())\n", + " correct_tensor = pred.eq(labels.float().view_as(pred))\n", + " correct = np.squeeze(correct_tensor.cpu().numpy())\n", + " num_correct += np.sum(correct)\n", + "\n", + "print(\"Test loss: {:.3f}\".format(np.mean(test_losses)))\n", + "\n", + "test_acc = (num_correct)/len(test_dataload.dataset) * 100\n", + "print(\"Test accuracy: {:.3f}\".format(test_acc))" + ] + } + ], + "metadata": { + "kernelspec": { + 
"display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Project_files/.ipynb_checkpoints/Task1_Word_Embeddings-checkpoint.ipynb b/Project_files/.ipynb_checkpoints/Task1_Word_Embeddings-checkpoint.ipynb new file mode 100644 index 0000000..2a856b7 --- /dev/null +++ b/Project_files/.ipynb_checkpoints/Task1_Word_Embeddings-checkpoint.ipynb @@ -0,0 +1,1394 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "_VZXi_KGi0UR" + }, + "source": [ + "# Task 1: Word Embeddings (10 points)\n", + "\n", + "This notebook will guide you through all steps necessary to train a word2vec model (Detailed description in the PDF)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "48t-II1vkuau" + }, + "source": [ + "## Imports\n", + "\n", + "This code block is reserved for your imports. \n", + "\n", + "You are free to use the following packages: \n", + "\n", + "(List of packages)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "4kh6nh84-AOL" + }, + "outputs": [], + "source": [ + "import re,string\n", + "\n", + "import os\n", + "import time\n", + "import torch\n", + "import pickle\n", + "import datetime\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from pathlib import Path\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "from torch.autograd import Variable\n", + "\n", + "from tqdm.auto import tqdm, trange\n", + "from matplotlib import pyplot as plt\n", + "\n", + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]='2,3,4,5'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=================================\n", + "GPU found\n", + "Using GPU at cuda: 0\n", + "=================================\n", + " \n" + ] + } + ], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "if device == 'cuda':\n", + " print(\"=================================\")\n", + " print(\"GPU found\")\n", + " print(\"Using GPU at cuda:\",torch.cuda.current_device())\n", + " print(\"=================================\")\n", + " print(\" \")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NWmk3hVllEcU" + }, + "source": [ + "# 1.1 Get the data (0.5 points)\n", + "\n", + "The Hindi portion HASOC corpus from [github.io](https://hasocfire.github.io/hasoc/2019/dataset.html) is already available in the repo, at data/hindi_hatespeech.tsv . Load it into a data structure of your choice. Then, split off a small part of the corpus as a development set (~100 data points).\n", + "\n", + "If you are using Colab the first two lines will let you upload folders or files from your local file system." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "id": "XtI7DJ-0-AOP" + }, + "outputs": [], + "source": [ + "data = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\\t')\n", + "data_text = data['text']\n", + "textd = data_text[:100]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
text_idtexttask_1task_2task_3
0hasoc_hi_5556बांग्लादेश की शानदार वापसी, भारत को 314 रन पर ...NOTNONENONE
1hasoc_hi_5648सब रंडी नाच देखने मे व्यस्त जैसे ही कोई #शांती...HOFPRFNUNT
2hasoc_hi_164तुम जैसे हरामियों के लिए बस जूतों की कमी है शु...HOFPRFNTIN
3hasoc_hi_3530बीजेपी MLA आकाश विजयवर्गीय जेल से रिहा, जमानत ...NOTNONENONE
4hasoc_hi_5206चमकी बुखार: विधानसभा परिसर में आरजेडी का प्रदर...NOTNONENONE
5hasoc_hi_5121मुंबई में बारिश से लोगों को काफी समस्या हो रही हैNOTNONENONE
6hasoc_hi_7142Ahmed's dad:-- beta aaj teri mammy kyu nahi ba...NOTNONENONE
7hasoc_hi_43215 लाख मुसलमान उर्स में, अजमेर की दरगाह पर आते ...NOTNONENONE
8hasoc_hi_4674Do mahashaktiyan mili hain, charo taraf khusi ...NOTNONENONE
9hasoc_hi_1637Chants of 'Jai Sri Ram' as Owaisi takes oath: ...NOTNONENONE
\n", + "
" + ], + "text/plain": [ + " text_id text task_1 \\\n", + "0 hasoc_hi_5556 बांग्लादेश की शानदार वापसी, भारत को 314 रन पर ... NOT \n", + "1 hasoc_hi_5648 सब रंडी नाच देखने मे व्यस्त जैसे ही कोई #शांती... HOF \n", + "2 hasoc_hi_164 तुम जैसे हरामियों के लिए बस जूतों की कमी है शु... HOF \n", + "3 hasoc_hi_3530 बीजेपी MLA आकाश विजयवर्गीय जेल से रिहा, जमानत ... NOT \n", + "4 hasoc_hi_5206 चमकी बुखार: विधानसभा परिसर में आरजेडी का प्रदर... NOT \n", + "5 hasoc_hi_5121 मुंबई में बारिश से लोगों को काफी समस्या हो रही है NOT \n", + "6 hasoc_hi_7142 Ahmed's dad:-- beta aaj teri mammy kyu nahi ba... NOT \n", + "7 hasoc_hi_4321 5 लाख मुसलमान उर्स में, अजमेर की दरगाह पर आते ... NOT \n", + "8 hasoc_hi_4674 Do mahashaktiyan mili hain, charo taraf khusi ... NOT \n", + "9 hasoc_hi_1637 Chants of 'Jai Sri Ram' as Owaisi takes oath: ... NOT \n", + "\n", + " task_2 task_3 \n", + "0 NONE NONE \n", + "1 PRFN UNT \n", + "2 PRFN TIN \n", + "3 NONE NONE \n", + "4 NONE NONE \n", + "5 NONE NONE \n", + "6 NONE NONE \n", + "7 NONE NONE \n", + "8 NONE NONE \n", + "9 NONE NONE " + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D-mSJ8nUlupB" + }, + "source": [ + "## 1.2 Data preparation (0.5 + 0.5 points)\n", + "\n", + "* Prepare the data by removing everything that does not contain information. \n", + "User names (starting with '@') and punctuation symbols clearly do not convey information, but we also want to get rid of so-called [stopwords](https://en.wikipedia.org/wiki/Stop_word), i. e. words that have little to no semantic content (and, but, yes, the...). Hindi stopwords can be found [here](https://github.com/stopwords-iso/stopwords-hi/blob/master/stopwords-hi.txt) Then, standardize the spelling by lowercasing all words.\n", + "Do this for the development section of the corpus for now.\n", + "\n", + "* What about hashtags (starting with '#') and emojis? Should they be removed too? Justify your answer in the report, and explain how you accounted for this in your implementation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "#re, string and datetime are used below; imported here in case they were not loaded earlier\n", + "import re\n", + "import string\n", + "import datetime\n", + "\n", + "def punctuations_remove(input):\n", + " output = \"\".join([x for x in input if x not in string.punctuation])\n", + " return output\n", + "\n", + "def numbers_remove(input):\n", + " output = re.sub(r\"[0-9]+\", \"\", input)\n", + " return output\n", + "\n", + "def usernames_remove(input):\n", + " output = re.sub(r\"@\\S+\", \"\", input)\n", + " return output\n", + "\n", + "def hashtag_remove(input):\n", + " output = re.sub(r\"#\\S+\", \"\", input)\n", + " return output\n", + "\n", + "def http_remove(input):\n", + " output = re.sub(r\"http\\S+\", \"\", input)\n", + " return output\n", + "\n", + "def emojis_remove(input):\n", + " EMOJI_PATTERN = re.compile(\n", + " \"[\"\n", + " \"\\U0001F1E0-\\U0001F1FF\" # flags (iOS)\n", + " \"\\U0001F300-\\U0001F5FF\" # symbols & pictographs\n", + " \"\\U0001F600-\\U0001F64F\" # emoticons\n", + " \"\\U0001F680-\\U0001F6FF\" # transport & map symbols\n", + " \"\\U0001F700-\\U0001F77F\" # alchemical symbols\n", + " \"\\U0001F780-\\U0001F7FF\" # Geometric Shapes Extended\n", + " \"\\U0001F800-\\U0001F8FF\" # Supplemental Arrows-C\n", + " \"\\U0001F900-\\U0001F9FF\" # Supplemental Symbols and Pictographs\n", + " \"\\U0001FA00-\\U0001FA6F\" # Chess Symbols\n", + " \"\\U0001FA70-\\U0001FAFF\" # Symbols and Pictographs Extended-A\n", + " \"\\U00002702-\\U000027B0\" # Dingbats\n", + " \"\\U000024C2-\\U0001F251\" \n", + " \"]+\"\n", + " )\n", + " \n", + " output = EMOJI_PATTERN.sub(r'',input)\n", + " return output\n", + "\n", + "def extra_whitespaces(input):\n", + " #str.replace would match '\\s+' literally, so a regex substitution is needed here\n", + " output = re.sub(r\"\\s+\", \" \", input)\n", + " return output\n", + "\n", + "#Load the stopword lists once instead of re-downloading them for every tweet\n", + "hindi_stopwords = pd.read_csv('https://raw.githubusercontent.com/stopwords-iso/stopwords-hi/master/stopwords-hi.txt').stack().tolist()\n", + "english_stopwords = pd.read_csv('https://raw.githubusercontent.com/stopwords-iso/stopwords-en/master/stopwords-en.txt').stack().tolist()\n", + "stopwords = hindi_stopwords + english_stopwords\n", + "\n", + "def stopwords_remove(m):\n", + " #Returns a one-element Series holding the token list; applying this over the corpus\n", + " #expands the result into a single-column DataFrame, which tokens() below relies on\n", + " output = pd.Series(m).apply(lambda x: [item for item in x.split() if item not in stopwords])\n", + " return output\n", + "\n", + "def tolower(input):\n", + " output = input.lower()\n", + " return output\n", + "\n", + "def corpus_preprocess(corpus):\n", + " corpus = corpus.apply(lambda x: tolower(x))\n", + " corpus = corpus.apply(lambda x: emojis_remove(x))\n", + " corpus = corpus.apply(lambda x: http_remove(x))\n", + " corpus = corpus.apply(lambda x: hashtag_remove(x))\n", + " corpus = corpus.apply(lambda x: numbers_remove(x))\n", + " corpus = corpus.apply(lambda x: usernames_remove(x))\n", + " corpus = corpus.apply(lambda x: punctuations_remove(x))\n", + " corpus = corpus.apply(lambda x: extra_whitespaces(x)) #collapse whitespace while each row is still a string\n", + " corpus = corpus.apply(lambda x: stopwords_remove(x)) #tokenises, so it must come last\n", + " return corpus" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Started Preprocessing\n", + "Preprocessing ended!\n", + "Pre-processing the text took 0:00:08.811462\n", + "===========================================================\n" + ] + } + ], + "source": [ + "#textd holds the development section of the corpus (the first 100 tweets, as the outputs below suggest)\n", + "textd = data['text'][:100]\n", + "\n", + "print(\"Started Preprocessing\")\n", + "cleanstart = time.time()\n", + "cleantext = corpus_preprocess(textd)\n", + "cleanend = str(datetime.timedelta(seconds = time.time()-cleanstart))\n", + "print(\"Preprocessing ended!\")\n", + "print(\"Pre-processing the text took {}\".format(cleanend))\n", + "print(\"===========================================================\")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "def tokens(text):\n", + " #text arrives as a single-column DataFrame whose cells are token lists;\n", + " #join each list back into a sentence string\n", + " c = []\n", + " for sent in text[0]:\n", + " a = \" \".join(sent)\n", + " c.append(a)\n", + " df_text = pd.DataFrame(c,columns=[\"text\"])\n", + " return df_text" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "text = tokens(cleantext)\n", + "text = text['text'].apply(lambda x: x.split())" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 [बांग्लादेश, शानदार, वापसी, भारत, रन, रोका]\n", + "1 [सब, रंडी, नाच, देखने, व्यस्त, होगा, सब, शुरू,...\n", + "2 [तुम, हरामियों, बस, जूतों, कमी, शुक्र, तुम्हार...\n", + "3 [बीजेपी, mla, आकाश, विजयवर्गीय, जेल, रिहा, जमा...\n", + "4 [चमकी, बुखार, विधानसभा, परिसर, आरजेडी, प्रदर्श...\n", + " ... \n", + "95 [देश, पहली, बार, सरकार, प्रो, इंकम्बेंसी, जनाद...\n", + "96 [आदमी, आदमी, मैं, पानी, बारे, सोचता, थालिखने, ...\n", + "97 [मादरजात, सनी, तेरे, पास, टाइम, नही, तोतेरी, म...\n", + "98 [थोर, क्रांतिकारक, राणी, लक्ष्मीबाई, यांना, पु...\n", + "99 [मुस्लिम, लोगों, वोट, मांगने, वाली, पार्टियां,...\n", + "Name: text, Length: 100, dtype: object" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Je09nozLmmMm" + }, + "source": [ + "## 1.3 Build the vocabulary (0.5 + 0.5 points)\n", + "\n", + "The input to the first layer of word2vec is a one-hot encoding of the current word. The output of the model is then compared to a numeric class label of the words within the size of the skip-gram window. Now\n", + "\n", + "* Compile a list of all words in the development section of your corpus and save it in a variable ```V```."
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "VpoGmTKx-AOQ" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of unique words are: 1097\n" + ] + } + ], + "source": [ + "V = list(set(text.sum())) #List of unique words in the corpus\n", + "all_words = list(text.sum()) #All the words without removing duplicates\n", + "print(\"Total number of unique words are: \",len(V))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['स्पेशली',\n", + " 'दी।',\n", + " 'hai',\n", + " 'पहचान',\n", + " 'बताया',\n", + " 'khus',\n", + " 'साहिब',\n", + " 'जाती',\n", + " 'नया',\n", + " 'श्रद्धांजलि',\n", + " 'जनता',\n", + " 'कप',\n", + " 'सवाल',\n", + " 'तैसी',\n", + " 'wale',\n", + " 'हैँ',\n", + " 'भारतीय',\n", + " 'समर्थक',\n", + " 'शत्शत्',\n", + " 'तुमहृदय']" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V[:20]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "#Dictionaries of words and their indexes\n", + "word_index = {word: i for i,word in enumerate(V)}\n", + "index_word = {i: word for i,word in enumerate(V)}" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "152" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "word_index['मुखिया']" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'जाती'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index_word[7]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WiaVglVNoENY" + }, + "source": [ + "* Then, write a function ```word_to_one_hot``` that returns a one-hot encoding of an arbitrary word in the vocabulary. The size of the one-hot encoding should be ```len(V)```." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "yqPNw6IT-AOQ" + }, + "outputs": [], + "source": [ + "def word_to_one_hot(word):\n", + " idx = word_index[word] #dictionary lookup instead of a linear V.index scan\n", + " onehot = [0.] * len(V)\n", + " onehot[idx] = 1.\n", + " return torch.tensor(onehot)\n", + "\n", + "#Memoise the one-hot vectors so each is built only once\n", + "get_onehot = dict((word, word_to_one_hot(word)) for word in V)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0., 0., 0., ..., 0., 0., 0.])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_onehot['मुखिया']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gKD8zBlxVclh" + }, + "source": [ + "## 1.4 Subsampling (0.5 points)\n", + "\n", + "The probability to keep a word in a context is given by:\n", + "\n", + "$P_{keep}(w_i) = \Big(\sqrt{\frac{z(w_i)}{0.001}}+1\Big) \cdot \frac{0.001}{z(w_i)}$\n", + "\n", + "Where $z(w_i)$ is the relative frequency of the word $w_i$ in the corpus. Now,\n", + "* Calculate word frequencies\n", + "* Define a function ```sampling_prob``` that takes a word (string) as input and returns the probability to **keep** the word in a context."
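, + "\n", + "Worked example: for a word with relative frequency $z(w_i) = 0.004$, $P_{keep} = (\sqrt{0.004/0.001} + 1) \cdot \frac{0.001}{0.004} = (2 + 1) \cdot 0.25 = 0.75$. For rarer words with $z(w_i)$ below roughly $0.0026$ the formula yields values above 1 (as in the output below), which simply means such words are always kept when the result is compared against a uniform random threshold."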
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "Mj4sDOVMMr0b" + }, + "outputs": [], + "source": [ + "def sampling_prob(word):\n", + " if word in all_words:\n", + " count = all_words.count(word)\n", + " zw_i = count / len(all_words)\n", + " p_wi_keep = (np.sqrt(zw_i/0.001) + 1)*(0.001/zw_i)\n", + " else:\n", + " p_wi_keep = 0\n", + " return p_wi_keep" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2.7641229712289954" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sampling_prob('मुखिया')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kxV1P90zplxu" + }, + "source": [ + "## 1.5 Skip-Grams (1 point)\n", + "\n", + "Now that you have the vocabulary and one-hot encodings at hand, you can start to do the actual work. The skip-gram model requires training data of the shape ```(current_word, context)```, with ```context``` being the words before and/or after ```current_word``` within ```window_size```. \n", + "\n", + "* Have a closer look at the original paper. Once you understand how skip-gram works, implement a function ```get_target_context``` that takes a sentence as input and [yield](https://docs.python.org/3.9/reference/simple_stmts.html#the-yield-statement)s a ```(current_word, context)```.\n", + "\n", + "* Use your ```sampling_prob``` function to drop words from contexts as you sample them. " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "r8CCTpVy-AOR" + }, + "outputs": [], + "source": [ + "def get_target_context(sentence,window):\n", + " thres = np.random.random() #one keep-threshold is drawn per sentence\n", + "\n", + " for i,word in enumerate(sentence):\n", + " target = word_index[sentence[i]]\n", + "\n", + " #the range ends at i + window + 1 so the full right-hand window is included\n", + " for j in range(i - window, i + window + 1):\n", + " if j!=i and j <= len(sentence)-1 and j>=0:\n", + " if sampling_prob(sentence[j]) > thres:\n", + " context = word_index[sentence[j]]\n", + " yield target,context" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gfEFgtkmuDjL" + }, + "source": [ + "## 1.6 Hyperparameters (0.5 points)\n", + "\n", + "According to the word2vec paper, what would be a good choice for the following hyperparameters? \n", + "\n", + "* Embedding dimension\n", + "* Window size\n", + "\n", + "Initialize them in a dictionary or as independent variables in the code block below. " + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "d7xSKuFJcYoD" + }, + "outputs": [], + "source": [ + "# Set hyperparameters\n", + "window_size = 2\n", + "embedding_size = 300\n", + "\n", + "# More hyperparameters\n", + "learning_rate = 0.05\n", + "epochs = 300\n", + "batch_size = 60" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xiM2zq-YunPx" + }, + "source": [ + "## 1.7 Pytorch Module (0.5 + 0.5 + 0.5 points)\n", + "\n", + "Pytorch provides a wrapper for your fancy and super-complex models: [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). The code block below contains a skeleton for such a wrapper. Now,\n", + "\n", + "* Initialize the two weight matrices of word2vec as fields of the class.\n", + "\n", + "* Override the ```forward``` method of this class. It should take a one-hot encoding as input, perform the matrix multiplications, and finally apply a log softmax on the output layer.\n", + "\n", + "* Initialize the model and save its weights in a variable. 
The Pytorch documentation will tell you how to do that." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D9sGNytYhwxS", + "outputId": "41645b64-e4ed-4e6a-e10f-74cb39b92230" + }, + "outputs": [], + "source": [ + "# Create model \n", + "class Word2Vec(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.v_len = len(V)\n", + " self.es = embedding_size\n", + " self.epochs = epochs\n", + " \n", + " self.w1 = nn.Linear(len(V),embedding_size,False) #input weights, no bias\n", + " self.w2 = nn.Linear(embedding_size,len(V))\n", + " self.soft = nn.LogSoftmax(dim = 1)\n", + "\n", + " def forward(self, one_hot):\n", + " one_hot = self.w1(one_hot)\n", + " one_hot = self.w2(one_hot)\n", + " output = self.soft(one_hot)\n", + " return output #already on the GPU once the model has been moved there\n", + "\n", + " def softmax(self,input): \n", + " output = self.soft(input)\n", + " return output" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XefIDMMHv5zJ" + }, + "source": [ + "## 1.8 Loss function and optimizer (0.5 points)\n", + "\n", + "Initialize variables with [optimizer](https://pytorch.org/docs/stable/optim.html#module-torch.optim) and loss function. You can take what is used in the word2vec paper, but you can use alternative optimizers/loss functions if you explain your choice in the report." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "id": "V9-Ino-e29w3" + }, + "outputs": [], + "source": [ + "# Define optimizer and loss\n", + "model = Word2Vec().cuda()\n", + "optimizer = optim.SGD(model.parameters(), lr=learning_rate,momentum=0.9,nesterov=True)\n", + "criterion = nn.NLLLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=====================================\n", + "The Word2Vec model: \n", + "Word2Vec(\n", + " (w1): Linear(in_features=1097, out_features=300, bias=False)\n", + " (w2): Linear(in_features=300, out_features=1097, bias=True)\n", + " (soft): LogSoftmax(dim=1)\n", + ")\n", + "=====================================\n" + ] + } + ], + "source": [ + "print(\"=====================================\")\n", + "print(\"The Word2Vec model: \")\n", + "print(model)\n", + "print(\"=====================================\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ckTfK78Ew8wI" + }, + "source": [ + "## 1.9 Training the model (3 points)\n", + "\n", + "As everything is prepared, implement a training loop that performs several passes of the data set through the model. You are free to do this as you please, but your code should:\n", + "\n", + "* Load the weights saved in 1.7 at the start of every execution of the code block\n", + "* Print the accumulated loss at least after every epoch (the accumulated loss should be reset after every epoch)\n", + "* Define a criterion for the training procedure to terminate if a certain loss value is reached. You can find the threshold by observing the loss for the development set.\n", + "\n", + "You can play around with the number of epochs and the learning rate."
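, + "\n", + "A minimal sketch of the weight-reloading requirement (the filename ```init.pt``` is illustrative and assumes the freshly initialized state was saved earlier with ```torch.save(model.state_dict(), 'init.pt')```):\n", + "\n", + "```python\n", + "model.load_state_dict(torch.load('init.pt')) # restore the saved initial weights before training\n", + "```"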
+ ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "'''Builds the training data from the corpus: each pair holds the one-hot encoding\n", + "of a target word and the vocabulary index of one of its context words,\n", + "returned together in a dataframe.\n", + "'''\n", + "def get_training_data(corpus,window):\n", + " t,c = [],[]\n", + " for sentence in corpus:\n", + " data = get_target_context(sentence,window)\n", + " for i,j in data:\n", + " x = get_onehot[index_word[i]]\n", + " t.append(x)\n", + " c.append(j)\n", + " t_data = pd.DataFrame(list(zip(t,c)),columns=[\"target\",\"context\"])\n", + " return t_data" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "id": "LbMGD5L0mLDx" + }, + "outputs": [], + "source": [ + "def train(traindata,batchsize):\n", + " losses = []\n", + " print(\"Training started\")\n", + " for epoch in range(1,epochs+1):\n", + " total_loss = []\n", + " #two sequential DataLoaders are zipped so every target batch stays aligned with its context batch\n", + " for wt,wc in zip(DataLoader(traindata.target.values,batch_size=batchsize),\n", + " DataLoader(traindata.context.values,batch_size=batchsize)):\n", + " wt = wt.cuda()\n", + " wc = wc.cuda()\n", + " optimizer.zero_grad()\n", + " output = model(wt)\n", + " loss = criterion(output,wc)\n", + " total_loss.append(loss.item())\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " if epoch % 50 == 0:\n", + " start = time.time()\n", + " print(\"===========================================================\")\n", + " print(\"Saving the model state\")\n", + " save_model(epoch)\n", + " end = str(datetime.timedelta(seconds = time.time()-start))\n", + " print(\"Model state saved. It was completed in {}\".format(end)) \n", + " print(\"===========================================================\")\n", + "\n", + " #early-stopping criterion: terminate once the mean epoch loss falls below 1.2\n", + " if np.mean(total_loss) < 1.2:\n", + " break\n", + " print(\"At epoch {} the loss is {}\".format(epoch ,round(np.mean(total_loss),3)))\n", + " losses.append(np.mean(total_loss))\n", + "\n", + " plt.xlabel(\"Epochs\")\n", + " plt.ylabel(\"LOSS\")\n", + " plt.plot(losses)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def save_model(epoch):\n", + " torch.save(model.state_dict(),\"epoch_{}.pt\".format(epoch))" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=================================================\n", + "Collecting training data\n", + "It took 0:00:00.318723 to collect the data\n", + "The training data has 4028 target-context pairs\n", + "===================================================\n", + "Training started\n", + "At epoch 1 the loss is 7.006\n", + "At epoch 2 the loss is 6.99\n", + "At epoch 3 the loss is 6.974\n", + "At epoch 4 the loss is 6.96\n", + "At epoch 5 the loss is 6.946\n", + "At epoch 6 the loss is 6.932\n", + "At epoch 7 the loss is 6.92\n", + "At epoch 8 the loss is 6.907\n", + "At epoch 9 the loss is 6.895\n", + "At epoch 10 the loss is 6.884\n", + "At epoch 11 the loss is 6.872\n", + "At epoch 12 the loss is 6.861\n", + "At epoch 13 the loss is 6.849\n", + "At epoch 14 the loss is 6.837\n", + "At epoch 15 the loss is 6.825\n", + "At epoch 16 the loss is 6.812\n", + "At epoch 17 the loss is 6.798\n", + "At epoch 18 the loss is 6.783\n", + "At epoch 19 the loss is 6.767\n", + "At epoch 20 the loss is 6.751\n", + "At epoch 21 the loss is 6.736\n", + "At epoch 22 the loss is 6.723\n", + "At epoch 23 the loss is 6.711\n", + "At epoch 24 the loss is 
6.699\n", + "At epoch 25 the loss is 6.688\n", + "At epoch 26 the loss is 6.677\n", + "At epoch 27 the loss is 6.666\n", + "At epoch 28 the loss is 6.654\n", + "At epoch 29 the loss is 6.643\n", + "At epoch 30 the loss is 6.631\n", + "At epoch 31 the loss is 6.618\n", + "At epoch 32 the loss is 6.605\n", + "At epoch 33 the loss is 6.591\n", + "At epoch 34 the loss is 6.576\n", + "At epoch 35 the loss is 6.561\n", + "At epoch 36 the loss is 6.544\n", + "At epoch 37 the loss is 6.527\n", + "At epoch 38 the loss is 6.508\n", + "At epoch 39 the loss is 6.489\n", + "At epoch 40 the loss is 6.468\n", + "At epoch 41 the loss is 6.446\n", + "At epoch 42 the loss is 6.423\n", + "At epoch 43 the loss is 6.398\n", + "At epoch 44 the loss is 6.373\n", + "At epoch 45 the loss is 6.345\n", + "At epoch 46 the loss is 6.317\n", + "At epoch 47 the loss is 6.287\n", + "At epoch 48 the loss is 6.256\n", + "At epoch 49 the loss is 6.224\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. It was completed in 0:00:00.011237\n", + "===========================================================\n", + "At epoch 50 the loss is 6.19\n", + "At epoch 51 the loss is 6.155\n", + "At epoch 52 the loss is 6.118\n", + "At epoch 53 the loss is 6.08\n", + "At epoch 54 the loss is 6.04\n", + "At epoch 55 the loss is 5.999\n", + "At epoch 56 the loss is 5.956\n", + "At epoch 57 the loss is 5.912\n", + "At epoch 58 the loss is 5.867\n", + "At epoch 59 the loss is 5.82\n", + "At epoch 60 the loss is 5.772\n", + "At epoch 61 the loss is 5.722\n", + "At epoch 62 the loss is 5.671\n", + "At epoch 63 the loss is 5.619\n", + "At epoch 64 the loss is 5.565\n", + "At epoch 65 the loss is 5.51\n", + "At epoch 66 the loss is 5.454\n", + "At epoch 67 the loss is 5.397\n", + "At epoch 68 the loss is 5.339\n", + "At epoch 69 the loss is 5.28\n", + "At epoch 70 the loss is 5.219\n", + "At epoch 71 the loss is 5.159\n", + "At epoch 72 the loss is 5.097\n", + "At epoch 73 the loss is 5.035\n", + "At epoch 74 the loss is 4.972\n", + "At epoch 75 the loss is 4.909\n", + "At epoch 76 the loss is 4.845\n", + "At epoch 77 the loss is 4.781\n", + "At epoch 78 the loss is 4.717\n", + "At epoch 79 the loss is 4.652\n", + "At epoch 80 the loss is 4.587\n", + "At epoch 81 the loss is 4.522\n", + "At epoch 82 the loss is 4.458\n", + "At epoch 83 the loss is 4.393\n", + "At epoch 84 the loss is 4.328\n", + "At epoch 85 the loss is 4.263\n", + "At epoch 86 the loss is 4.198\n", + "At epoch 87 the loss is 4.134\n", + "At epoch 88 the loss is 4.07\n", + "At epoch 89 the loss is 4.005\n", + "At epoch 90 the loss is 3.942\n", + "At epoch 91 the loss is 3.878\n", + "At epoch 92 the loss is 3.815\n", + "At epoch 93 the loss is 3.752\n", + "At epoch 94 the loss is 3.69\n", + "At epoch 95 the loss is 3.628\n", + "At epoch 96 the loss is 3.567\n", + "At epoch 97 the loss is 3.506\n", + "At epoch 98 the loss is 3.446\n", + "At epoch 99 the loss is 3.386\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.006575\n", + "===========================================================\n", + "At epoch 100 the loss is 3.326\n", + "At epoch 101 the loss is 3.268\n", + "At epoch 102 the loss is 3.21\n", + "At epoch 103 the loss is 3.153\n", + "At epoch 104 the loss is 3.096\n", + "At epoch 105 the loss is 3.04\n", + "At epoch 106 the loss is 2.985\n", + "At epoch 107 the loss is 2.931\n", + "At epoch 108 the loss is 2.877\n", + "At epoch 109 the loss is 2.825\n", + "At epoch 110 the loss is 2.773\n", + "At epoch 111 the loss is 2.722\n", + "At epoch 112 the loss is 2.673\n", + "At epoch 113 the loss is 2.624\n", + "At epoch 114 the loss is 2.576\n", + "At epoch 115 the loss is 2.53\n", + "At epoch 116 the loss is 2.484\n", + "At epoch 117 the loss is 2.44\n", + "At epoch 118 the loss is 2.398\n", + "At epoch 119 the loss is 2.356\n", + "At epoch 120 the loss is 2.316\n", + "At epoch 121 the loss is 2.278\n", + "At epoch 122 the loss is 2.241\n", + "At epoch 123 the loss is 2.206\n", + "At epoch 124 the loss is 2.172\n", + "At epoch 125 the loss is 2.14\n", + "At epoch 126 the loss is 2.109\n", + "At epoch 127 the loss is 2.081\n", + "At epoch 128 the loss is 2.053\n", + "At epoch 129 the loss is 2.028\n", + "At epoch 130 the loss is 2.003\n", + "At epoch 131 the loss is 1.98\n", + "At epoch 132 the loss is 1.959\n", + "At epoch 133 the loss is 1.938\n", + "At epoch 134 the loss is 1.919\n", + "At epoch 135 the loss is 1.9\n", + "At epoch 136 the loss is 1.883\n", + "At epoch 137 the loss is 1.867\n", + "At epoch 138 the loss is 1.851\n", + "At epoch 139 the loss is 1.836\n", + "At epoch 140 the loss is 1.822\n", + "At epoch 141 the loss is 1.809\n", + "At epoch 142 the loss is 1.797\n", + "At epoch 143 the loss is 1.785\n", + "At epoch 144 the loss is 1.773\n", + "At epoch 145 the loss is 1.763\n", + "At epoch 146 the loss is 1.752\n", + "At epoch 147 the loss is 1.742\n", + "At epoch 148 the loss is 1.733\n", + "At epoch 149 the loss is 1.724\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.010483\n", + "===========================================================\n", + "At epoch 150 the loss is 1.716\n", + "At epoch 151 the loss is 1.708\n", + "At epoch 152 the loss is 1.7\n", + "At epoch 153 the loss is 1.693\n", + "At epoch 154 the loss is 1.686\n", + "At epoch 155 the loss is 1.679\n", + "At epoch 156 the loss is 1.673\n", + "At epoch 157 the loss is 1.667\n", + "At epoch 158 the loss is 1.661\n", + "At epoch 159 the loss is 1.656\n", + "At epoch 160 the loss is 1.65\n", + "At epoch 161 the loss is 1.645\n", + "At epoch 162 the loss is 1.64\n", + "At epoch 163 the loss is 1.636\n", + "At epoch 164 the loss is 1.631\n", + "At epoch 165 the loss is 1.627\n", + "At epoch 166 the loss is 1.623\n", + "At epoch 167 the loss is 1.619\n", + "At epoch 168 the loss is 1.616\n", + "At epoch 169 the loss is 1.612\n", + "At epoch 170 the loss is 1.609\n", + "At epoch 171 the loss is 1.605\n", + "At epoch 172 the loss is 1.602\n", + "At epoch 173 the loss is 1.599\n", + "At epoch 174 the loss is 1.596\n", + "At epoch 175 the loss is 1.593\n", + "At epoch 176 the loss is 1.591\n", + "At epoch 177 the loss is 1.588\n", + "At epoch 178 the loss is 1.586\n", + "At epoch 179 the loss is 1.583\n", + "At epoch 180 the loss is 1.581\n", + "At epoch 181 the loss is 1.579\n", + "At epoch 182 the loss is 1.577\n", + "At epoch 183 the loss is 1.574\n", + "At epoch 184 the loss is 1.572\n", + "At epoch 185 the loss is 1.571\n", + "At epoch 186 the loss is 1.569\n", + "At epoch 187 the loss is 1.567\n", + "At epoch 188 the loss is 1.565\n", + "At epoch 189 the loss is 1.564\n", + "At epoch 190 the loss is 1.562\n", + "At epoch 191 the loss is 1.56\n", + "At epoch 192 the loss is 1.559\n", + "At epoch 193 the loss is 1.558\n", + "At epoch 194 the loss is 1.556\n", + "At epoch 195 the loss is 1.555\n", + "At epoch 196 the loss is 1.553\n", + "At epoch 197 the loss is 1.552\n", + "At epoch 198 the loss is 1.551\n", + "At epoch 199 the loss is 1.55\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.007587\n", + "===========================================================\n", + "At epoch 200 the loss is 1.549\n", + "At epoch 201 the loss is 1.548\n", + "At epoch 202 the loss is 1.547\n", + "At epoch 203 the loss is 1.546\n", + "At epoch 204 the loss is 1.545\n", + "At epoch 205 the loss is 1.544\n", + "At epoch 206 the loss is 1.543\n", + "At epoch 207 the loss is 1.542\n", + "At epoch 208 the loss is 1.541\n", + "At epoch 209 the loss is 1.54\n", + "At epoch 210 the loss is 1.539\n", + "At epoch 211 the loss is 1.538\n", + "At epoch 212 the loss is 1.537\n", + "At epoch 213 the loss is 1.537\n", + "At epoch 214 the loss is 1.536\n", + "At epoch 215 the loss is 1.535\n", + "At epoch 216 the loss is 1.534\n", + "At epoch 217 the loss is 1.534\n", + "At epoch 218 the loss is 1.533\n", + "At epoch 219 the loss is 1.532\n", + "At epoch 220 the loss is 1.532\n", + "At epoch 221 the loss is 1.531\n", + "At epoch 222 the loss is 1.53\n", + "At epoch 223 the loss is 1.53\n", + "At epoch 224 the loss is 1.529\n", + "At epoch 225 the loss is 1.529\n", + "At epoch 226 the loss is 1.528\n", + "At epoch 227 the loss is 1.528\n", + "At epoch 228 the loss is 1.527\n", + "At epoch 229 the loss is 1.527\n", + "At epoch 230 the loss is 1.526\n", + "At epoch 231 the loss is 1.525\n", + "At epoch 232 the loss is 1.525\n", + "At epoch 233 the loss is 1.524\n", + "At epoch 234 the loss is 1.524\n", + "At epoch 235 the loss is 1.524\n", + "At epoch 236 the loss is 1.523\n", + "At epoch 237 the loss is 1.523\n", + "At epoch 238 the loss is 1.522\n", + "At epoch 239 the loss is 1.522\n", + "At epoch 240 the loss is 1.521\n", + "At epoch 241 the loss is 1.521\n", + "At epoch 242 the loss is 1.521\n", + "At epoch 243 the loss is 1.52\n", + "At epoch 244 the loss is 1.52\n", + "At epoch 245 the loss is 1.519\n", + "At epoch 246 the loss is 1.519\n", + "At epoch 247 the loss is 1.519\n", + "At epoch 248 the loss is 1.518\n", + "At epoch 249 the loss is 1.518\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.005157\n", + "===========================================================\n", + "At epoch 250 the loss is 1.518\n", + "At epoch 251 the loss is 1.517\n", + "At epoch 252 the loss is 1.517\n", + "At epoch 253 the loss is 1.517\n", + "At epoch 254 the loss is 1.516\n", + "At epoch 255 the loss is 1.516\n", + "At epoch 256 the loss is 1.516\n", + "At epoch 257 the loss is 1.515\n", + "At epoch 258 the loss is 1.515\n", + "At epoch 259 the loss is 1.515\n", + "At epoch 260 the loss is 1.514\n", + "At epoch 261 the loss is 1.514\n", + "At epoch 262 the loss is 1.514\n", + "At epoch 263 the loss is 1.514\n", + "At epoch 264 the loss is 1.513\n", + "At epoch 265 the loss is 1.513\n", + "At epoch 266 the loss is 1.513\n", + "At epoch 267 the loss is 1.512\n", + "At epoch 268 the loss is 1.512\n", + "At epoch 269 the loss is 1.512\n", + "At epoch 270 the loss is 1.512\n", + "At epoch 271 the loss is 1.511\n", + "At epoch 272 the loss is 1.511\n", + "At epoch 273 the loss is 1.511\n", + "At epoch 274 the loss is 1.511\n", + "At epoch 275 the loss is 1.51\n", + "At epoch 276 the loss is 1.51\n", + "At epoch 277 the loss is 1.51\n", + "At epoch 278 the loss is 1.51\n", + "At epoch 279 the loss is 1.51\n", + "At epoch 280 the loss is 1.509\n", + "At epoch 281 the loss is 1.509\n", + "At epoch 282 the loss is 1.509\n", + "At epoch 283 the loss is 1.509\n", + "At epoch 284 the loss is 1.509\n", + "At epoch 285 the loss is 1.508\n", + "At epoch 286 the loss is 1.508\n", + "At epoch 287 the loss is 1.508\n", + "At epoch 288 the loss is 1.508\n", + "At epoch 289 the loss is 1.508\n", + "At epoch 290 the loss is 1.507\n", + "At epoch 291 the loss is 1.507\n", + "At epoch 292 the loss is 1.507\n", + "At epoch 293 the loss is 1.507\n", + "At epoch 294 the loss is 1.507\n", + "At epoch 295 the loss is 1.506\n", + "At epoch 296 the loss is 1.506\n", + "At epoch 297 the loss is 1.506\n", + "At epoch 298 the loss is 1.506\n", + "At epoch 299 the loss is 1.506\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.009836\n", + "===========================================================\n", + "At epoch 300 the loss is 1.506\n", + "Training finished.\n", + "It took 0:00:40.983541 to finish training the model\n" + ] + }, + { + "data": { + "image/png": "[base64 PNG omitted: matplotlib plot of the mean loss per epoch, falling from about 7.0 to about 1.5 over 300 epochs]", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "start2 = time.time()\n", + "print(\"=================================================\")\n", + "print(\"Collecting training data\")\n", + "starte = time.time()\n", + "data = get_training_data(text,window_size)\n", + "ende = str(datetime.timedelta(seconds = time.time()-starte))\n", + "print(\"It took {} to collect the data\".format(ende))\n", + "print(\"The training data has {} target-context pairs\".format(len(data)))\n", + "print(\"===================================================\")\n", + "\n", + "train(data,batch_size)\n", + "\n", + "end2 = str(datetime.timedelta(seconds = time.time()-start2))\n", + "print(\"Training finished.\")\n", + "print(\"It took {} to finish training the model\".format(end2))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BgQkaYstyj0Q" + }, + "source": [ + "# 1.10 Train on the full dataset (0.5 points)\n", + "\n", + "Now, go back to 1.1 and remove the restriction on the number of sentences in your corpus. Then, reexecute code blocks 1.2, 1.3 and 1.6 (or those relevant if you created additional ones). \n", + "\n", + "* Then, retrain your model on the complete dataset.\n", + "\n", + "* Now, the input weights of the model contain the desired word embeddings! Save them together with the corresponding vocabulary items (Pytorch provides a nice [functionality](https://pytorch.org/tutorials/beginner/saving_loading_models.html) for this)." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "id": "4x8hQP_bg4_g" + }, + "outputs": [], + "source": [ + "data_full = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\\t')\n", + "full_data = data_full['text']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Started Preprocessing\")\n", + "cleantext_full = corpus_preprocess(full_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_full = tokens(cleantext_full)\n", + "text_full = text_full['text'].apply(lambda x: x.split())\n", + "\n", + "v = list(set(text_full.sum())) \n", + "all_word = list(text_full.sum()) \n", + "\n", + "fullword_index = {word: i for i,word in enumerate(v)}\n", + "fullindex_word = {i: word for i,word in enumerate(v)}\n", + "\n", + "data = get_training_data(text,window_size)\n", + "train(data,batch_size)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "NNTI_final_project_task_1.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/Project_files/.ipynb_checkpoints/cnn-task-3-checkpoint.ipynb b/Project_files/.ipynb_checkpoints/cnn-task-3-checkpoint.ipynb new file mode 100644 index 0000000..1877941 --- /dev/null +++ b/Project_files/.ipynb_checkpoints/cnn-task-3-checkpoint.ipynb @@ -0,0 +1,1253 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": 
"8f2839f25d086af736a60e9eeb907d3b93b6e0e5" + }, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "import torch\n", + "import pickle\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "\n", + "import warnings\n", + "from matplotlib import pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from pandas.core.common import SettingWithCopyWarning\n", + "warnings.simplefilter(action=\"ignore\", category=SettingWithCopyWarning)\n", + "\n", + "from torch.utils.data import (TensorDataset, DataLoader, RandomSampler,SequentialSampler)\n", + "\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'\n", + "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "#Used to return first n items of the iterable as a list\n", + "from itertools import islice\n", + "\n", + "def take(n, iterable):\n", + " return list(islice(iterable, n))" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=================================\n", + "GPU found\n", + "Using GPU at cuda: 0\n", + "=================================\n", + " \n" + ] + } + ], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "if device == 'cuda':\n", + " print(\"=================================\")\n", + " print(\"GPU found\")\n", + " print(\"Using GPU at cuda:\",torch.cuda.current_device())\n", + " print(\"=================================\")\n", + " print(\" \")" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ + "model1 = torch.load(\"/nn_project/files/hindi/m4665_e300_lr0.05_bs60_es300/epoch_300.pt\")\n", + "\n", + "w1 = model1[\"w1.weight\"].T\n", + "w2 = model1[\"w2.weight\"]\n", + "\n", + "cleandata = pd.read_pickle(\"/nn_project/files/hindi/m4665_e300_lr0.05_bs60_es300/hindi_corpus_cleaned.pkl\")\n", + "word_index = pd.read_pickle(\"/nn_project/files/hindi/m4665_e300_lr0.05_bs60_es300/word_index.pkl\")\n", + "index_word = pd.read_pickle(\"/nn_project/files/hindi/m4665_e300_lr0.05_bs60_es300/index_word.pkl\")\n", + "V = pd.read_pickle(\"/nn_project/files/hindi/m4665_e300_lr0.05_bs60_es300/vocab.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\\t')\n", + "text = data[['text','task_1']]\n", + "text['text'] = cleandata['text'].apply(lambda x: x.split())\n", + "text['label'] = text['task_1'].apply(lambda x: 1 if x == 'HOF' else 0)\n", + "max_len = text.text.str.len().max()" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "#Function to count the labels in a dataset\n", + "def tag_count(input):\n", + " hcount,ncount = 0,0\n", + " for tag in input:\n", + " if tag == 1:\n", + " hcount+=1\n", + " else:\n", + " ncount+=1\n", + " return hcount,ncount" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [], + "source": [ + "'''Word-index dictionaries are updated with '' word which is used for padding,\n", + "that is to make the sentences uniform in length'''\n", + "\n", + 
"word_index[''] = len(V)\n", + "index_word[len(V)] = ''\n" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [], + "source": [ + "#Function to get the word embedding from the weight\n", + "def get_word_embedding(input):\n", + " index = word_index[input]\n", + " return w1[index]" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [], + "source": [ + "#Creates the embedding matrix by using the word embeddings and adds zeroes for all the '' words\n", + "def matrix_embeddings():\n", + " _ , emb_size = w1.shape\n", + " embedding_matrix = np.random.uniform(-1, 1, (len(word_index), emb_size))\n", + " embedding_matrix[word_index['']] = np.zeros((emb_size,))\n", + "\n", + " for word,i in take(len(V),word_index.items()):\n", + " temp = get_word_embedding(word)\n", + " if temp is not None:\n", + " embedding_matrix[i] = temp.cpu()\n", + " return embedding_matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [], + "source": [ + "#Encodes the sentences into their respective indexes\n", + "def encode(corpus):\n", + " sent_idx = []\n", + " i = 0\n", + " for sentence in corpus:\n", + " sentence += [''] * (max_len - len(sentence))\n", + " idx = [word_index[word] for word in sentence]\n", + " sent_idx.append(idx)\n", + " i+= 1\n", + " return np.array(sent_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "#store the encoding,labels and the embeddings\n", + "encoded_text = encode(text.text)\n", + "labels = np.array(text['label'])\n", + "embeds = torch.tensor(matrix_embeddings())" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-0.3367, 0.3636, -0.3046, ..., 0.0300, -0.2679, -0.4148],\n", + " [ 0.0883, 0.0739, 0.1094, ..., -0.1064, -0.0681, -0.2677],\n", + " [-0.2731, 0.1393, 0.2855, ..., 0.0173, 0.2927, 0.2280],\n", + " ...,\n", + " [ 0.2724, -0.0797, -0.0919, ..., -0.3751, 0.1591, 0.1360],\n", + " [-0.0535, 0.1561, -0.0446, ..., 0.0316, -0.3634, -0.2427],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", + " dtype=torch.float64)" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeds" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "#split the data into train and test sets with train set being 0.8 and test being the remaining 0.2\n", + "xtrain, xtest, ytrain, ytest = train_test_split(encoded_text,labels,shuffle=True,test_size=0.2,random_state=15)\n", + "\n", + "\n", + "#Creating a dataloader for train and test sets\n", + "def get_dataloader(traindata, testdata, trainlabels, testlabels ,batchsize):\n", + " \n", + " traindata = torch.tensor(traindata).float()\n", + " testdata = torch.tensor(testdata).float()\n", + " trainlabels = torch.tensor(trainlabels)\n", + " testlabels = torch.tensor(testlabels)\n", + " \n", + " train = TensorDataset(traindata,trainlabels)\n", + " train_dataload = DataLoader(train,sampler=RandomSampler(train),batch_size=batchsize,drop_last=True)\n", + " test = TensorDataset(testdata,testlabels)\n", + " test_dataload = DataLoader(test,sampler=RandomSampler(test),batch_size=batchsize,drop_last=True)\n", + "\n", + " return train_dataload,test_dataload\n", + "\n", + "batchsize = 50\n", + "train_dataload,test_dataload = get_dataloader(xtrain, xtest, ytrain, 
ytest,batchsize)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The train dataset has 1967 HOF labels and 1765 NOT labels.\n", + "The test dataset has 502 HOF labels and 431 NOT labels.\n" + ] + } + ], + "source": [ + "#Return the label count\n", + "th,tc = tag_count(ytrain)\n", + "testh,testc = tag_count(ytest)\n", + " \n", + "print(\"The train dataset has {} HOF labels and {} NOT labels.\".format(th,tc))\n", + "print(\"The test dataset has {} HOF labels and {} NOT labels.\".format(testh,testc))" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "#Creating the CNN model\n", + "class hindi_cnnmodel(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.word_embed = embeds\n", + " self.filter_sizes = [2, 3, 4] #Size of the Kernel\n", + " self.num_filters = [50, 50, 50] #3 convolution layers each with 50 filters\n", + " self.num_classes=2 #Number of classes\n", + " self.dropout=0.5 #Prevents overfitting\n", + " self.vlen,self.es = self.word_embed.shape\n", + " self.embedding = nn.Embedding.from_pretrained(self.word_embed) #Loading the trained word embedding matrix\n", + "\n", + " #1D convolution is used to detect the features in the sentences. Each filter returns a feature map \n", + " self.conv1d = nn.ModuleList([nn.Conv1d(in_channels=self.es,out_channels=self.num_filters[i],kernel_size=self.filter_sizes[i])\n", + " for i in range(len(self.filter_sizes))])\n", + " \n", + " self.fc = nn.Linear(np.sum(self.num_filters), self.num_classes)\n", + " self.dropout1 = nn.Dropout(p=self.dropout)\n", + "\n", + " def forward(self,input1):\n", + " x_e = self.embedding(input1).float()\n", + " x_r = x_e.permute(0,2,1)\n", + " #ReLU and maxpool is used to reduce the feature map into a single scalar\n", + " #Maxpool will capture the best feature from the feature map\n", + " conv_list = [F.relu(conv(x_r)) for conv in self.conv1d]\n", + " x_maxpool = [F.max_pool1d(x_conv, kernel_size=x_conv.shape[2]) for x_conv in conv_list]\n", + " \n", + " #Fully connected layer\n", + " x_fc = torch.cat([x_pool.squeeze(dim=2) for x_pool in x_maxpool],dim=1)\n", + " drop = self.dropout1(x_fc)\n", + " output = self.fc(drop)\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hindi_cnnmodel(\n", + " (embedding): Embedding(17788, 300)\n", + " (conv1d): ModuleList(\n", + " (0): Conv1d(300, 50, kernel_size=(2,), stride=(1,))\n", + " (1): Conv1d(300, 50, kernel_size=(3,), stride=(1,))\n", + " (2): Conv1d(300, 50, kernel_size=(4,), stride=(1,))\n", + " )\n", + " (fc): Linear(in_features=150, out_features=2, bias=True)\n", + " (dropout1): Dropout(p=0.5, inplace=False)\n", + ")\n" + ] + } + ], + "source": [ + "#Initializing the model\n", + "model = hindi_cnnmodel()\n", + "\n", + "#Sending the model to GPU\n", + "model.cuda()\n", + "\n", + "print(model)\n", + "\n", + "#Setting the parameters\n", + "learning_rate = 0.05\n", + "epochs = 10\n", + "\n", + "#We use SGD optimizer and CrossEntropyLoss\n", + "optimizer = optim.SGD(model.parameters(),lr=learning_rate)\n", + "criterion = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "def train(train_dataload,test_dataload):\n", + " accuracy = 0 #tracks the best test accuracy seen so far\n", + " print(\"Training started\")\n",
+ " start = time.time()\n", + " losses = []\n", + " test_loss = []\n", + "\n", + " for epoch in range(epochs):\n", + " totloss = 0\n", + " model.train()\n", + " \n", + " #Take a batch at a time from the dataloader\n", + " for i,batch in enumerate(train_dataload):\n", + " #Send the input and the label to GPU \n", + " x_train,y_train = tuple(x.to(torch.int64).cuda() for x in batch)\n", + " print(type(x_train))\n", + " #Compute the loss\n", + " model.zero_grad() #Make previous calculated gradients zero\n", + " output = model(x_train)\n", + " loss = criterion(output,y_train)\n", + " totloss += loss.item() \n", + " loss.backward() #Compute gradients\n", + " optimizer.step() #updates weights\n", + " \n", + " # Calculate the average loss over the entire training data\n", + " train_loss = totloss/len(train_dataload)\n", + " losses.append(train_loss) \n", + " \n", + " testloss,testacc = test_acc(test_dataload)\n", + " test_loss.append(testloss)\n", + " \n", + " if testacc > accuracy:\n", + " accuracy = testacc\n", + " \n", + " print(\"At epoch {} the training loss is {}, the test loss is {} and accuracy of {}%\".format(epoch,round(testloss,3),round(train_loss,3),round(accuracy,2)))\n", + " \n", + " #Plotting the train loss and the test loss at the end\n", + " plt.plot(losses,label=\"train\")\n", + " plt.plot(test_loss,label=\"test\")\n", + " plt.xlabel(\"EPOCHS\")\n", + " plt.ylabel(\"LOSS\")\n", + " plt.legend()\n", + " plt.show()\n", + " print(\"Training ended\")" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "def test_acc(test_dataload):\n", + " #By putting the model into evaluation mode, the dropout layers are stopped for the time being\n", + " model.eval()\n", + " testacc = []\n", + " testloss = []\n", + "\n", + " for batch in test_dataload:\n", + " x_test,y_test = tuple(x.to(torch.int64).cuda() for x in batch)\n", + "\n", + " with torch.no_grad():\n", + " output = model(x_test) #Computing logits\n", + " #Calculate the test loss\n", + " loss = criterion(output,y_test)\n", + " testloss.append(loss.item())\n", + " \n", + " #Calculate the predictions and its accuracy\n", + " preds = torch.argmax(output,dim=1).flatten()\n", + "\n", + " accu = (preds == y_test).cpu().numpy().mean() * 100\n", + " testacc.append(accu)\n", + " \n", + " x = np.mean(testloss)\n", + " y = np.mean(testacc)\n", + "\n", + " return x,y" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training started\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 0 the training loss is 0.565, the test loss is 0.639 and accuracy of 77.44%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 1 the training loss is 0.482, the test loss is 0.528 and accuracy of 78.11%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 2 the training loss is 0.45, the test loss is 0.451 and accuracy of 78.78%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 3 the training loss is 0.436, the test loss is 0.41 and accuracy of 79.56%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 4 the training loss is 0.423, the test loss is 0.376 and accuracy of 79.89%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 5 the training loss is 0.42, the test loss is 0.347 and accuracy of 80.56%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 6 the training loss is 0.426, the test loss is 0.317 and accuracy of 80.56%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 7 the training loss is 0.415, the test loss is 0.295 and accuracy of 81.0%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 8 the training loss is 0.414, the test loss is 0.266 and accuracy of 81.11%\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "At epoch 9 the training loss is 0.423, the test loss is 0.251 and accuracy of 81.11%\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAA2OElEQVR4nO3deXhV5bX48e/KTCBAgDCGkDATQKYwKCAoICAKWocqzqDYXq36U7nF1lbFetV6a9WrtVVBrBMOrYoKClZwQFHCJCRMIQxJmMIQ5szr98fewCEmBJJzsjOsz/OcJ2dP71454ll5h/2+oqoYY4wxJQV5HYAxxpjqyRKEMcaYUlmCMMYYUypLEMYYY0plCcIYY0ypQrwOwF+aNWum8fHxXodhjDE1yrJly/aoakxpx2pNgoiPjyc5OdnrMIwxpkYRka1lHQtoE5OIjBGR9SKSJiLTyjjnahFJFZEUEXnLZ3+RiKx0X3MCGacxxpifC1gNQkSCgReAUUAmsFRE5qhqqs85nYAHgMGqul9EmvsUcUxVewcqPmOMMacXyBrEACBNVdNVNR+YDUwocc5twAuquh9AVXcHMB5jjDFnIZB9EG2ADJ/tTGBgiXM6A4jIYiAYeFhVP3OPRYhIMlAIPKGqH5a8gYhMAaYAxMXF+TV4Y0zdUFBQQGZmJrm5uV6HElARERHExsYSGhp6xtd43UkdAnQChgOxwNci0lNVc4B2qpolIu2BL0Vktapu8r1YVV8CXgJISkqySaWMMWctMzOTqKgo4uPjERGvwwkIVWXv3r1kZmaSkJBwxtcFsokpC2jrsx3r7vOVCcxR1QJV3QxswEkYqGqW+zMdWAT0CWCsxpg6Kjc3l6ZNm9ba5AAgIjRt2vSsa0mBTBBLgU4ikiAiYcA1QMnRSB/i1B4QkWY4TU7pIhItIuE++wcDqRhjTADU5uRwXEV+x4AlCFUtBO4EPgfWAu+qaoqITBeR8e5pnwN7RSQVWAhMVdW9QDcgWURWufuf8B395E9H8gr582fr2Lb3aCCKN8aYGiugz0Go6lxV7ayqHVT1MXffH1V1jvteVfVeVU1U1Z6qOtvd/5273cv9OSNQMR7KLWTWd1uY/klKoG5hjDFlysnJ4W9/+9tZX3fxxReTk5Pj/4B81Pm5mFo2iuCuEZ34Yu1uvly3y+twjDF1TFkJorCw8LTXzZ07l8aNGwcoKkedTxAAkwYn0CGmPo98nEpuQZHX4Rhj6pBp06axadMmevfuTf/+/Rk6dCjjx48nMTERgMsuu4x+/frRvXt3XnrppRPXxcfHs2fPHrZs2UK3bt247bbb6N69OxdddBHHjh3zS2xeD3OtFsJCgnh4fHdumPEjL3+dzm9GdPI6JGOMBx75OIXU7Qf9WmZi64Y8dGn3Mo8/8cQTrFmzhpUrV7Jo0SLGjRvHmjVrTgxHnTlzJk2aNOHYsWP079+fK664gqZNm55SxsaNG3n77bd5+eWXufrqq/nXv/7F9ddfX+nYrQbhGtophrE9WvLCojQy91uHtTHGGwMGDDjlWYXnnnuOXr16MWjQIDIyMti4cePPrklISKB3794A9OvXjy1btvglFqtB+HjwkkQWrt/Nnz5Zy99v6Od1OMaYKna6v/SrSv369U+8X7RoEV988QXff/89kZGRDB8+vNRnGcLDw0+8Dw4O9lsTk9UgfLRpXI87L+jIZyk7+XpDttfhGGPqgKioKA4dOlTqsQMHDhAdHU1kZCTr1q1jyZIlVRqbJYgSbju/PfFNI3n44xTyC4u9DscYU8s1bdqUwYMH06NHD6ZOnXrKsTFjxlBYWEi3bt2YNm0agwYNqtLYRLV2TGGUlJSk/lowaOG63dwyaynTxnblV8M6+KVMY0z1tHbtWrp16+Z1GFWitN9VRJapalJp51sNohQXdG3OyG4teO4/G9lxwD9tecYYU9NYgijDQ5cmUlisPPbpWq9DMcYYT1iCKEPbJpH8elgHPvlpB99t2uN1OMYYU+UsQZzGr4d3IDa6Hg99lEJBkXVYG2PqFksQpxERGswfL0lk4+7DvPbdFq/DMcaYKmUJohyjElswvEsMz3yxkd0Ha/eShMYY48sSRDlEhIcu7U5+YTFPzFvndTjGmFqmotN9AzzzzDMcPRq4qYEsQZyBhGb1ue38BP69IoulW/Z5HY4xphapswlCRMaIyHoRSRORaWWcc7WIpIpIioi85bP/JhHZ6L5uCmScZ+KOCzrSulEEf/hwDYXWYW2M8RPf6b6nTp3KU089Rf/+/TnnnHN46KGHADhy5Ajjxo2jV69e9OjRg3feeYfnnnuO7du3c8EFF3DBBRcEJLaATdYnIsHAC8AoIBNYKiJzfJcOFZFOwAPAYFXdLyLN3f1NgIeAJECBZe61+wMVb3kiw0J48JJE/uvN5bz5wzZuOi/eq1CMMYEybxrsXO3fMlv2hLFPlHnYd7rv+fPn8/777/Pjjz+iqowfP56vv/6a7OxsWrduzaeffgo4czQ1atSIp59+moULF9KsWTP/xuwKZA1iAJCmqumqmg/MBiaUOOc24IXjX/yqutvdPxpYoKr73GMLgDEBjPWMjO3RkiEdm/G/89ez53Ce1+EYY2qZ+fPnM3/+fPr06UPfvn1Zt24dGzdupGfPnixYsIDf/va3fPPNNzRq1KhK4gnkdN9tgAyf7UxgYIlzOgOIyGIgGHhYVT8r49o2JW8gIlOAKQBxcXF+C7wsIsLD4xMZ88w3/Pmzdfz5yl4Bv6cxpgqd5i/9qqCqPPDAA9x+++0/O7Z8+XLmzp3Lgw8+yIgRI/jjH/8Y8Hi87qQOAToBw4FrgZdFpPGZXqyqL6lqkqomxcTEBCbCEjo2j2LykATeTc5k+TbPWryMMbWE73Tfo0ePZubMmRw+fBiArKwsdu/ezfbt24mMjOT6669n6tSpLF++/GfXBkIgaxBZQFuf7Vh3n69M4AdVLQA2i8gGnISRhZM0fK9dFLBIz9JvRnTiw5VZPPRRCh/eMZjgIPE6JGNMDeU73ffYsWOZOHEi5557LgANGjTgjTfeIC0tjalTpxIUFERoaCgvvvgiAFOmTGHMmDG0bt2ahQsX+j22gE33LSIhwAZgBM4X/lJgoqqm+JwzBrhWVW8SkWbACqA3bsc00Nc9dTnQT1XLHGPqz+m+z8RHK7O4e/ZKHru8B9cNbFdl9zXG+JdN9+3BdN+qWgjcCXwOrAXeVdUUEZkuIuPd0z4H9opIKrAQmKqqe91E8ChOUlkKTD9dcvDC+F6tGZjQhKc+X8/+I/leh2OMMX4X0D4IVZ2rqp1VtYOqPubu+6OqznHfq6req6qJqtpTVWf7XDtTVTu6r1cDGWdFiAjTJ/TgUG4hT81f73U4xhjjd153UtdoXVpGcdO58bz94zZ+yszxOhxjTAXVlpU1T6civ6MliEq6Z1QnmtYP548fpVBcXPv/kRlT20RERLB3795anSRUlb179xIREXFW1wVyFFOd0DAilAfGduW+91bx/rJMru7ftvyLjDHVRmxsLJmZmWRnZ3sdSkBFREQQGxt7Vt
dYgvCDy/u04a0ft/HkZ+sY3b0ljSJDvQ7JGHOGQkNDSUhI8DqMasmamPwgKEiYPqE7+4/m8/QC67A2xtQOliD8pHvrRlw/qB2vL9lK6vaDXodjjDGVZgnCj+4b1YXGkWE8NGdNre7wMsbUDZYg/KhRZCi/HdOFpVv288GKkrOKGGNMzWIJws+u6teWXm0b8z9z13Eot8DrcIwxpsIsQfhZUJAwfXx39h7J45kvNnodjjHGVJgliADo1bYx1/Rvy6zvtrBhV+Cm4jXGmECyBKEKy/8Jx3L8WuzU0V1pEB7CHz+yDmtjTM1kCWLPRvjk/8G/boXiIr8V26R+GPeP7sKS9H188tMOv5VrjDFVxRJETGcY+2dIWwBfPurXoicOiKNHm4Y89ulajuQV+rVsY4wJNEsQAP0nQ7+b4du/wpp/+a3Y4CDhkfE92Hkwl+e+tA5rY0zNYgniuLFPQdtB8OEdsOMnvxXbr100V/aLZea3m0nbfdhv5RpjTKBZgjguJAyu/ifUi4bZ18GRPX4r+rdjuhIRGswjH6dYh7UxpsYIaIIQkTEisl5E0kRkWinHbxaRbBFZ6b5u9TlW5LN/TiDjPCGqBVzzJhzeBe/dDEX+edAtJiqce0d15puNe/g8ZadfyjTGmEALWIIQkWDgBWAskAhcKyKJpZz6jqr2dl+v+Ow/5rN/fCnXBUabvjD+OdjyDcx/0G/F3jCoHV1bRvHoJ2s5lu+/0VLGGBMogaxBDADSVDVdVfOB2cCEAN7Pf3pdA4PugB/+Dive8EuRIcFBPDK+O1k5x/jbojS/lGmMMYEUyATRBsjw2c5095V0hYj8JCLvi4jvcmwRIpIsIktE5LLSbiAiU9xzkv2+GtSo6ZAwzHlGIjPZL0UObN+Uy3q35h9fpbNlzxG/lGmMMYHidSf1x0C8qp4DLABe8znWTlWTgInAMyLSoeTFqvqSqiapalJMTIx/IwsOgatmQVQreOd6OOSfvoPfXdyN0GCxDmtjTLUXyASRBfjWCGLdfSeo6l5VzXM3XwH6+RzLcn+mA4uAPgGMtXSRTeCatyD3gJMkCvPKv6YczRtGcM/Izixcn81/1u72Q5DGGBMYgUwQS4FOIpIgImHANcApo5FEpJXP5nhgrbs/WkTC3ffNgMFAagBjLVvLHnDZi5C5FD69z5m7qZJuHhxPx+YNeOSTFHILrMPaGFM9BSxBqGohcCfwOc4X/7uqmiIi00Xk+Kiku0QkRURWAXcBN7v7uwHJ7v6FwBOq6k2CAOh+GQy9H1a8DktfKff08oQGBzF9fHcy9h3jH1+lVz4+Y4wJAKkt7eBJSUmanOyfzuRSFRfD7Gsh7Qu48SOIH1LpIu94azlfpO7ii3uH0bZJpB+CNMaYsyMiy9z+3p/xupO65ggKgl+8BNEJ8O6NkLOt0kX+/uJuBInw6CfeVY6MMaYsliDORkQjuPZt5wnr2ddB/tFKFde6cT1+M6Ij81N3sWi9dVgbY6oXSxBnq1knuGIG7FwNc35T6U7ryUMSaN+sPo98nEpeoXVYG2OqD0sQFdH5IhjxB1jzPnz3XKWKCg8J5qHx3dm85wivfLPZTwEaY0zlWYKoqCH3QuJl8MXDTsd1JQzrHMPo7i14/ss0NtsT1saYasISREWJwGV/g+aJ8P4k2LupUsX94ZJEwkODmDRrKfuP5PspSGOMqThLEJURVt+ZHlyCYfZEyDtU4aJioyN5+cYksvYf4/bXl1l/hDHGc5YgKis63pmzac9G+PftzvMSFdQ/vgn/e3Uvftyyj6nv/URxce14RsUYUzNZgvCH9sNg9GOw/lP46slKFTW+V2umju7CnFXbeXrBBj8FaIwxZy/E6wBqjYG/ctay/uoJZ/6mbpdWuKj/Gt6BbXuP8vzCNOKaRHJ1/7blX2SMMX5mNQh/EYFL/gpt+sEHv4LdaytRlPCny3swtFMzfvfBar7d6L/1sY0x5kxZgvCn0Aj45RtO5/XsiXBsf8WLCg7ihev60iGmAb9+Yxnrd1a8A9wYYyrCEoS/NWwNV78OORnO8Nfiio9GahgRysxb+hMRFsykWUvZfTDXj4EaY8zpWYIIhLiBMO4vsOlL50G6SmjTuB4zb+rPviP5TH4tmaP5hf6J0RhjymEJIlD63QT9b3Wm4vjpvUoV1TO2Ef93bR9Sth/grrdXUmTDX40xVcASRCCNfhzizoM5d8L2lZUqamRiCx66tDtfrN1l04MbY6pEQBOEiIwRkfUikiYi00o5frOIZIvISvd1q8+xm0Rko/u6KZBxBkxIGFz9T4hs5kwPfji7UsXddF48kwYnMOu7Lby62Cb2M8YEVsAShIgEAy8AY4FE4FoRSSzl1HdUtbf7esW9tgnwEDAQGAA8JCLRgYo1oBrEwDVvwNE98N5NzloSlfD7cd24KLEF0z9JZX7KTj8FaYwxPxfIGsQAIE1V01U1H5gNTDjDa0cDC1R1n6ruBxYAYwIUZ+C17gPjn4eti+Gzn1WkzkpwkPDsNX04p00j7p69kp8yc/wTozHGlBDIBNEGyPDZznT3lXSFiPwkIu+LyPFHhs/oWhGZIiLJIpKcnV255puAO+cqOO83sPQVWPZapYqqFxbMKzf1p0n9MCbNSiZzf+VWtjPGmNJ43Un9MRCvqufg1BLO6ptTVV9S1SRVTYqJiQlIgH418hHocCF8eh9s+6FSRcVEhTPrlv7kFRYxadZSDuZWrunKGGNKCmSCyAJ8JxGKdfedoKp7VTXP3XwF6Hem19ZIQcFw5UxoFAvv3gAHt1equE4tovjH9f1Izz7Cr99YRn5hxWeSNcaYkgKZIJYCnUQkQUTCgGuAOb4niEgrn83xwPEJjD4HLhKRaLdz+iJ3X81XLxqufRvyj8A710NB5Z6OPq9jM5644hwWp+3l9x+sRiu5RrYxxhwXsAShqoXAnThf7GuBd1U1RUSmi8h497S7RCRFRFYBdwE3u9fuAx7FSTJLgenuvtqheTe4/B+QtQw+vRcq+aV+Zb9Y7hrRifeWZfLCwjQ/BWmMqeuktvzFmZSUpMnJyV6HcXYWPu5MDz7mSRj0q0oVparc++4qPliRxbPX9GZC79LGAxhjzKlEZJmqJpV2zOtO6rpt2G+hyzj4/HeQ/lWlihIRnriiJwMSmjD1vZ/4cXPtqXAZY7xhCcJLQUFw+d+haUd472bYv6VSxYWHBPPSDf2IbVKPKa8nk5592C9hGmPqJksQXoto6HRaa5EzHcf+rZUqrnFkGLNuHkCwCLfMWsrew3nlX2SMMaWwBFEdNO3gDH/dsxH+ry98dAfsS69wcXFNI3n5piR2Hsjltn8mk1tQ8TUpjDF1lyWI6qLjSLhrBSRNhtXvw/8lwb9vd5JGBfSNi+aZX/ZmRUYO9727imKbItwYc5YsQVQnjdrAxX+Gu1fBoF9D6kfwfH94f3KF1rge27MVD4ztyqerd/Dnz9cHIGBjTG1mCaI6imoJox+De1bD4Lth/Tz427nw7o2wc/VZFXXb0PZcNzCOv3+1ibd+2BaggI0xtZEliOqsQQyMesRJFEPvg
00L4e9D4O2JsH3FGRUhIjwyvjvDu8Twh4/WsGj97gAHbYypLSxB1AT1m8KIP8A9P8HwB2Drt/DScHjzKshYWu7lIcFBPD+xL11aRHHnWytYu+Ng4GM2xtR4liBqknrRMHwa3LMGLvwDZCbDjJHw+uWw9fvTXtogPISZN/enQXgIk2YtZeeBys0BZYyp/SxB1EQRDeH8+52mp1HTYcdP8OoYmHUJbP6mzLmdWjaKYObN/Tl4rIBJs5ZyOK+wigM3xtQkliBqsvAGTif2Path9P/Ang3w2iXw6ljY9GWpiSKxdUNeuK4v63cd4jdvLaewyKYIN8aUzhJEbRAWCefe4QyPHfsU5Gxzmp1eGQkb5v8sUQzv0pzpE7qzcH02j3ycalOEG2NKddoEISKXikg7n+0/isgqEZkjIgmBD8+cldB6MHCK88DdJX+Fw7vhraucDu11n56SKK4b2I7bh7Xn9SVbmfHtZu9iNsZUW+XVIB4DsgFE5BLgemASzsI/fw9saKbCQsIhaRLctRzGPw+5OTB7Ivx9KKR8CMVOs9JvR3dlXM9WPDZ3LfNW7/A0ZGNM9VNeglBVPeq+/wUwQ1WXqeorQA1YBLqOCw6FvjfAncucBYoKj8F7N8GL58Hq9wmimL9c3Ys+bRtzzzsrWbFtv9cRG2OqkfIShIhIAxEJAkYA//E5FhG4sIxfBYdAr2vgjh/hihmAwr8mwwsDiUh9j5ev702LhhHc+loy2/YeLbc4Y0zdUF6CeAZYCSQDa1U1GUBE+gDltkmIyBgRWS8iaSIy7TTnXSEiKiJJ7na8iBwTkZXuy5qz/CEoGHpeCb/+Hq56zWmK+uB2mr46mPcHboKiAm6Z9SMHjhZ4Hakxphood8lREWkDNAdWqWqxu68lEKaqZU7uIyLBwAZgFJCJs7b0taqaWuK8KOBTIAy4U1WTRSQe+ERVe5zpL1Ijlxz1WnExbJgHXz0JO1aRV78Nfzo4lk2tx/PMdQNp3tAqicbUdhVectQdwXRYVVeoarGIXCAizwITgZ3l3HcAkKaq6aqaD8wGJpRy3qPAk4A92lvVgoKg6ziY8hVMfI/wxq14NPgVnt95HSl/uZiVr0+jaP3ncDjb60iNMR4IKef4u8DlwAER6Q28BzwO9AL+Btx6mmvbABk+25nAQN8TRKQv0FZVPxWRqSWuTxCRFcBB4EFV/abkDURkCjAFIC4urpxfxZRJBDpfBJ1GwaYvCV36Fp02LaV12t8J2vSic07DWGjdG1r3OfmKbOJp2MaYwCovQdRT1e3u++uBmar6F7fTemVlbuyW8TRwcymHdwBxqrpXRPoBH4pId1U9ZZY5VX0JeAmcJqbKxGNwEkXHEUR1HEEDVeYuS+PDefNol7eeCcG7SNyVSvC6T06e37jdqQmjVS+o19iz8I0x/lVeghCf9xcCDwC4zU3llZ0FtPXZjnX3HRcF9AAWuWW1BOaIyHi3MzzPvdcyEdkEdMbpLDdVQEQYl9SJId3j+cv89YxfspWYBuFMvySW0U12IttXOlOOb18BqR+evLBJhxJJ4xwIj/Lq1zDGVMJpO6nd/oZWOH/Rjwc6q2qBiLQCPi6rY8O9NgSnk3oETmJYCkxU1ZQyzl8E3O92UscA+1S1SETaA98APVV1X1n3s07qwFqVkcPvPlhNyvaDnN85hkcndKdd0/rOwaP7TiaL7Stg+0o4mOleKdCs86lJo2VPZ3oQY4znTtdJXV6CEOCXOEniXVXNcvf3AZqr6ufl3PhinKGywTjNU4+JyHQgWVXnlDh3EScTxBXAdKAAKAYeUtWPT3cvSxCBV1hUzOtLtvKX+RsoKCrmzgs6MmVYe8JDgn9+8uHdTqI4kTSWw+FdzjEJgphubsLoDa37QovuEGqjpoypahVOED4FJADd3c1UVU33Y3x+YQmi6uw8kMujn6Ty6eodtI+pz58u68F5HZqVf+HBHSVqGsvh6F7nWFAINE88tabRPBFCwgL7yxhTx1WmBtEQeAXoB6xyd/cGlgGTS3Yae8kSRNVbuH43f/xoDRn7jvGLPm343bhuNGsQfuYFqMKBzBJJY4UzdxRAcBi06HEyYbTpC826OE+GG2P8ojIJYhawBZju85CcAH8AOqrqjX6PtoIsQXgjt6CI579M4x9fbyIyLITfjunKNf3bEhRU7iCG0qnC/i1O7eJ4E9WOVZDn/i0SUs/p+PataTTt6Dwlbow5a5VJEBtVtdPZHvOCJQhvpe0+xIMfrmFJ+j76xjXmT5f1JLF1Q/8UXlwM+9JPbZrasQoK3Hmjwho4Q2x9k0Z0gvMgoDHmtAKVINJUtaOfYqw0SxDeU1U+WJHFY5+uJedYAZMGx3PPyM7UDw9Ak1BxkbOCnm/T1M7VUOg+kB/eCFqXSBqN2znPehhjTqhMgngN2AQ8qj4nisgfcIa83uDvYCvKEkT1kXM0nyc/W8/bP26jVaMIHh7fnYsSW3AGz85UTlEBZK8rkTTWQLE7+WC96FMTRus+0LCNJQ1Tp1W2k3oG0JeTT073BlbgdFIf8GuklWAJovpZtnUfv/9gDet2HmJkt+Y8PL47sdFV/PxDYR7sTj2ZMLJWONta5Byv3/znSSOqRdXGWB2oOv08R/Y4r6N74Ei2+9p78n3uAWjc1hmm3LyrM9KsSQcbbVaD+WOYawcg0d1MVdVNInKPqj7jvzArxxJE9VRQVMyrizfz1wUbAbh7ZCcmD0kgNNjD/oGCY07NwremsWc9OOMwIKp1iaapts5yriH1nGc1QurVjJFU+UdOfuEfyfb50vfZdyTbGWp8JBuK8ksvJywK6jeD+jHOU/E5W50+oeOfV1CIM1Agxk0Yzbs6CaRJ+5rxOdVUxcVwaIczqEOLIWFohYqpdIIoo9BtqlptZsizBFG9ZeUc45E5KcxP3UXnFg147PKe9I+vRpP95R12+jB8k8bejWWfHxQCoZEQEnEyaYRG+Oyrd/Kn7/tT9tU7w2vrOR3uBbnul3x5X/puDaCgjMWfQiOdL/xI90u/fgzUb+r8PLGv2clzSnuAsSDX6QPKXge717o/U2H/VsD9TgkOc56ij+l6srYR0xWi423U2ZnKP+IkgJKvfZudRH08qbfuA1MWVegWgUoQGaratvwzq4YliJphQeouHp6TQlbOMa5OimXa2G40qV9NmydyD8COn5wnwAtznZpHwbGT7wtznS/hglxnOdcTP4+Vse8YJ748z1ZQ6Mm+lJKCw09+oZ/4kvfZLrkvrH6FP5Jy5R91amO71/okjnVwwGfpmJAIJ3H41jaad4NGbeveyLPiYji88+df/sffH9l96vnhDZ0EW/LVtIPzswKsBmGqlaP5hTz7n43M+GYzUREhPHBxN67qFxv4TmyvqTp/8VUkuRTmOsN5T/x1HwORTU82+1T3zy7vEGSvP7W2sXsdHNp+8pzQ+hDTxUkWzbud7Oeo6QMJ8o86f+2X/PLfv8XZf3zkHTjT0DSMheh2zhd+kwSfRJDgDLTw82dRmU7qQ5T+J4/gTAVebRoYLUHUPOt2HuTBD9aQvHU/A+Kb8KfLe9C5
hc38WqccyynRTOXWPHz/cg5veLKZ6nhto2kHp+YkQU5zlYjzXoJ99gX5vAKYYFSdWmbJL//97vbxOciOC4uCJvGl1AQSnFpUFXf4B6QGUd1YgqiZiouV95Zl8Pi8dRzOLeS289tz14WdqBdmbdR12tF9btJwE8budc7743N3nTUpkTSOJ5KgUvYFn0wqpSYcNyEFBbu1g21OTc/3Xo1i3S/9ds4X//EEEB3vLLRVjWpEliBMtbf3cB6Pz1vH+8syiY2ux/QJ3bmwax0cbmpO73C22xm+BYoLndE7Wuw8OHn8vbrvi4t/vu9n55bcd/ynlrLv+Hl6cl9w2Kk1gCYJTnIIOYs5yTxmCcLUGEvS9/Lgh2tI232Ykd1aMHV0F7q0tGYnYwLldAmijg0ZMNXdoPZNmXvXUP57TBeWpO9lzLNfc/fsFWzec8Tr0Iypc6wGYaqt/Ufy+cfX6cz6bjMFRcqVfWO5a2Qn2jSu53VoxtQantUgRGSMiKwXkTQRmXaa864QERWRJJ99D7jXrReR0YGM01RP0fXDmDa2K1//9wXcMKgdH6zI4oKnFvHQR2vYfTC3/AKMMZUSsBqEiATjrEk9CsjEWZP6WlVNLXFeFPApEAbc6S45mgi8DQwAWgNf4EwOWFTW/awGUfttzznG/32ZxnvJGQQHCTedF8+vhnWovg/aGVMDeFWDGACkqWq6quYDs4EJpZz3KPAk4Psn4QRgtqrmqepmIM0tz9RhrRvX4/Ff9OQ/9w1jXM9WvPxNOkOf/JKn56/nwLEynjI2xlRYIBNEGyDDZzvT3XeCiPQF2qrqp2d7rXv9FBFJFpHk7Oxs/0Rtqr12Tevz9C97M/+e8xnepTnPfZnG+X9eyAsL0ziSV+h1eMbUGp6NYhKRIOBp4L6KlqGqL6lqkqomxcTE+C84UyN0ahHFC9f15ZPfDCGpXTRPfb6eYU8tZMa3m8ktKLM10hhzhgKZILIA38n8Yt19x0UBPYBFIrIFGATMcTuqy7vWmBN6tGnEjJv7869fn0eXllE8+kkqw59axBtLtpJfWOx1eMbUWIHspA7B6aQegfPlvhSYqKopZZy/CLjf7aTuDrzFyU7q/wCdrJPanInvNu3hL/M3sGzrfto2qcfdIzpzWe/WhHi5BoUx1ZQnndSqWgjcCXwOrAXeVdUUEZkuIuPLuTYFeBdIBT4D7jhdcjDG13kdmvH+r87l1Vv606heKPe/t4qLnvmaj1dtp7i4djz3Y0xVsAflTK2mqnyesounF6xnw67DdG0ZxX0XdWFkt+a1f3pxY86ATbVh6iwRYUyPlsy7+3yevaY3uQVF3PbPZC7723d8szGb2vIHkjGBYAnC1AnBQcKE3m1YcO8wnryiJ3sO5XHDjB/55UtLWLpln9fhGVMtWROTqZPyCot4Z2kG//dlGtmH8ji/cwz3X9SZc2Ibex2aMVXKpvs2pgzH8ot4fckWXly0if1HC7gosQX3XtSZri0beh2aMVXCEoQx5TicV8jMbzfz8tfpHM4v5NJzWnPPyE60j2ngdWjGBJQlCGPOUM7RfF76Op1XF28hr7CIK/rG8psLOxHXNNLr0IwJCEsQxpylPYfzeHHRJl5fspXComJGJbZg8pD29I+PtuGxplaxBGFMBe06mMs/v9/Cmz9sI+doAT3bNGLSkHjG9WxNWIgNAjQ1nyUIYyrpWH4RH6zIYubizaTtPkzzqHBuPLcdEwe2s/UoTI1mCcIYPykuVr7emM3MxVv4ekM24SFB/KJvGyYNTqBTiyivwzPmrJ0uQYRUdTDG1GRBQcLwLs0Z3qU5G3cdYubiLfx7eSZv/5jB0E7NmDwkgWGdY6yfwtQKVoMwppL2HcnnrR+28s/vt7L7UB4dmzfglsHx/KJPLPXCgr0Oz5jTsiYmY6pAfmExn67ezoxvN7Mm6yCNI0OZOCCOG8+Np2WjCK/DM6ZUliCMqUKqytIt+5nxbTrzU3cRLMK4c1oxeUiCTeVhqh3rgzCmCokIAxKaMCChCdv2HuW177fwztIMPlq5naR20UweksBF3VsSHGT9FKZ6sxqEMVXgUG4B7yZnMuu7zWTsO0ZsdD1uPi+eq/u3pWFEqNfhmTrMsyYmERkDPAsEA6+o6hMljv8KuAMoAg4DU1Q1VUTicVahW++eukRVf3W6e1mCMDVBUbGyIHUXMxdv5sfN+6gfFsxVSW25ZXA87ZrW9zo8Uwd5kiBEJBhnTepRQCbOmtTXqmqqzzkNVfWg+3488F+qOsZNEJ+oao8zvZ8lCFPTrM48wMzFm/nkp+0UFisju7Vg8pAEBiY0sWGypsp41QcxAEhT1XQ3iNnABJx1pgE4nhxc9YHa0d5lzBnoGduIv/6yN9PGduX177fy5g9bWZC6i+6tGzJpcAKX9GpFeIgNkzXeCeRkMm2ADJ/tTHffKUTkDhHZBPwZuMvnUIKIrBCRr0RkaADjNMZTLRpGcP/oLnz/wAge/0VP8guLue+9VQx5ciHP/Wcjew/neR2iqaMC2cR0JTBGVW91t28ABqrqnWWcPxEYrao3iUg40EBV94pIP+BDoHuJGgciMgWYAhAXF9dv69atAfldjKlKqso3G/cwc/FmFq3PJiwkiMt7t2Hy0AQ623Qexs+86oM4F3hYVUe72w8AqOrjZZwfBOxX1UalHFsE3K+qZXYyWB+EqY3Sdp+cziO3oJhhnWO4bWh7Bndsav0Uxi9OlyAC2cS0FOgkIgkiEgZcA8wpEVgnn81xwEZ3f4zbyY2ItAc6AekBjNWYaqlj8yj+5/KefD9tBPeN6kzK9oNcP+MHLn7uW/61LJP8wmKvQzS1WKCHuV4MPIMzzHWmqj4mItOBZFWdIyLPAiOBAmA/cKeqpojIFcB0d38x8JCqfny6e1kNwtQFuQVFzFm5nZe/SWfj7sO0aBjOzeclMHFAHI0i7XkKc/Zsqg1jahlV5asN2bz8TTqL0/YSGRbM1UltmTwkgbZNbHlUc+YsQRhTi6VsP8CMbzYzZ9V2ilUZ26MVtw5NoE9ctNehmRrAEoQxdcCOA8eY9d0W3vphG4dyC0lqF82tQ9szKrGFzftkymQJwpg65HBeIe8uzWDm4s1k7j9GfNNIJg9J4Mp+bW19CvMzliCMqYMKi4r5PGUXL32TzqqMHBpHhnL9wHbceF47mkfZ+hTGYQnCmDpMVUneup+Xv05nwdpdhAYFcVmf1tw6tL09eGdsPQhj6jIRoX98E/rHN2HzniPM/HYz7y3L4N3kTHvwzpyW1SCMqYP2HcnnzSVbee37rew5nEe3Vg25dUgCl/ZqTVhIIJ+fNdWNNTEZY0plD94ZSxDGmNNSVRZtyOYVe/CuzrEEYYw5Y/bgXd1iCcIYc9bKevBuRLfmhAZbP0VtYQnCGFNhxx+8m/HtZrJyjtGoXigXJbZgbM+WDO7YzFa9q+EsQRhjKq2wqJiF67OZt3oHC1J3cSivkKjwEEYmtmBMj5YM6xxDRKgli5rGnoMwxlRaSHAQoxJbMCqxBXmFRXyXtpe5q3ewYO0uPliRRWRYMBd2bc7YHq24oGs
MkWH29VLT2X9BY8xZCw8J5oKuzbmga3MKiopZkr6Xuat3Mj9lJ5/8tIOI0CCGd27O2J4tubBrc6IibMhsTWRNTMYYvykqVn7cvI/P1uxg3pqd7D6UR1hwEEM7NWNsz1aM6tbCnq+oZqwPwhhT5YqLleXb9jNvzU7mrd7B9gO5hAQJgzs2Y2yPllzUvSVN6od5HWad51mCEJExwLM4S46+oqpPlDj+K+AOoAg4DExR1VT32APAZPfYXar6+enuZQnCmOpLVVmVeYB5a3Ywb/VOtu07SnCQMKh9E8b0aMXo7i1shlmPeJIgRCQY2ACMAjKBpcC1xxOAe05DVT3ovh8P/JeqjhGRROBtYADQGvgC6KyqRWXdzxKEMTWDqpKy/SCfrdnJ3DU7SM8+ggj0b9eEsT1bMqZHS1o1qud1mHWGV6OYBgBpqpruBjEbmACcSBDHk4OrPnA8W00AZqtqHrBZRNLc8r4PYLzGmCogIvRo04gebRpx30Wd2bj7MHNXOzWLRz5O5ZGPU+kT15iLe7RiTI+WNtWHhwKZINoAGT7bmcDAkieJyB3AvUAYcKHPtUtKXNumlGunAFMA4uLi/BK0MabqiAidW0TRuUUU94zszKbsw07NYvUOHpu7lsfmrqVnm0aM7dmSsT1akdCsvtch1ymBbGK6Ehijqre62zcAA1X1zjLOnwiMVtWbROR5YImqvuEemwHMU9X3y7qfNTEZU7ts23uUeWt2MHfNTlZl5ADQtWUUF/dsxcU9W9GxeQNvA6wlvGpiygLa+mzHuvvKMht4sYLXGmNqmbimkdw+rAO3D+tAVs4xPnNHQ/31iw08vWAD53VoyqTBCVzYtTlBQbbYUSAEsgYRgtNJPQLny30pMFFVU3zO6aSqG933lwIPqWqSiHQH3uJkJ/V/gE7WSW2M2XUwlw9WZPHad1vYcSCXhGb1uWVwPFf0jaV+uD37e7a8HOZ6MfAMzjDXmar6mIhMB5JVdY6IPAuMBAqA/cCdxxOIiPwemAQUAveo6rzT3csShDF1S0FRMfPW7GTGt5tZlZFDw4gQrh0Yx03nxtO6sY2COlP2oJwxplZbtnU/M7/dzLw1OxARxvZoyeQhtobFmbDJ+owxtVq/dtH0axdN5v6jvPbdFmb/mMEnP+2gb1xjJg9pz+juLQixNSzOmtUgjDG1zuG8Qt5PzuDV77awde9R2jSux03nteOX/eNoVM/mgvJlTUzGmDqpqFj5z9pdzPh2Mz9s3kdkWDBX9YvllsEJxNszFYAlCGOMYU3WAWYu3szHq7ZTWKyM6NqCyUMSGNS+CSJ1d5isJQhjjHHtPpjLG0u28sYP29h3JJ/EVg2ZNCSBS3u1qpPLp1qCMMaYEnILivhwRRYzF29mw67DxESFc8Ogdlw3MI6mDcK9Dq/KWIIwxpgyqCrfbNzDzMWbWbQ+m7CQIC7v3YZJQxLo0jLK6/ACzoa5GmNMGUSE8zvHcH7nGNJ2H2Lm4i38e3km7yRnMLRTMyYNSWBYp5g6OZ2H1SCMMaaE/UfyeevHbbz23RZ2H8qjQ0x9bhmcwBV9Y6kXVrv6KayJyRhjKiC/sJi5q3cw49vNrM46QOPIUCYOiOPGc+Np2ah2rIBnCcIYYypBVVm6ZT8zvk1nfuougkUYd04rrh/Ujl6xjQkLqblPaVsfhDHGVIKIMCChCQMSmrBt71FmfbeFd5Mz+GjldsJCgujRuiF94qLpE9eYPnHRtG4UUSuerbAahDHGVMCh3AK+3biHFRk5rNi2n58yD5BXWAxA86hw+vokjJ5tGlXbvgurQRhjjJ9FRYQytmcrxvZsBTjTj6/bcYjl2/azYtt+VmTk8FnKTgCCg4RuraJOJo220bRrGlntaxlWgzDGmADZeziPlRk5btLIYVVGDkfynXXPoiNDnWapto3p2y6ac2IbERVR9RMJWg3CGGM80LRBOCO6tWBEtxaAM3ngxt2HWLHNaZZavi2HL9ftBkAEOjePcpulGtM3LpoOMQ08ff4i0CvKjQGexVlR7hVVfaLE8XuBW3FWjcsGJqnqVvdYEbDaPXWbqo4/3b2sBmGMqYkOHCtgVUaOkzQynJrGgWMFAESFh9A7rjF92jp9Gb3bNia6fphf7+/JMFcRCcZZk3oUkImzJvW1qprqc84FwA+qelREfg0MV9VfuscOq2qDM72fJQhjTG2gqqTvOXKilrFiWw7rdh6k2P2qbt+svpM03Oapri2jKrUYkldNTAOANFVNd4OYDUwATiQIVV3oc/4S4PoAxmOMMdWeiNAhpgEdYhpwZb9YAI7kFbI668CJvoyvN2Tz7+VZANQLDWZEt+Y8P7Gv32MJZIJoA2T4bGcCA09z/mRgns92hIgk4zQ/PaGqH/o9QmOMqQHqh4cwqH1TBrVvCji1jMz9x1iRkcPyrfuJDNAQ2mrRSS0i1wNJwDCf3e1UNUtE2gNfishqVd1U4ropwBSAuLi4KovXGGO8JCK0bRJJ2yaRjO/VOmD3CeTz4VlAW5/tWHffKURkJPB7YLyq5h3fr6pZ7s90YBHQp+S1qvqSqiapalJMTIx/ozfGmDoukAliKdBJRBJEJAy4Bpjje4KI9AH+gZMcdvvsjxaRcPd9M2AwPn0XxhhjAi9gTUyqWigidwKf4wxznamqKSIyHUhW1TnAU0AD4D33icLjw1m7Af8QkWKcJPaE7+gnY4wxgWdPUhtjTB12umGuNXeOWmOMMQFlCcIYY0ypLEEYY4wplSUIY4wxpao1ndQikg1srUQRzYA9fgqnprPP4lT2eZzKPo+TasNn0U5VS32QrNYkiMoSkeSyevLrGvssTmWfx6ns8ziptn8W1sRkjDGmVJYgjDHGlMoSxEkveR1ANWKfxans8ziVfR4n1erPwvogjDHGlMpqEMYYY0plCcIYY0yp6nyCEJExIrJeRNJEZJrX8XhJRNqKyEIRSRWRFBG52+uYvCYiwSKyQkQ+8ToWr4lIYxF5X0TWichaETnX65i8JCL/z/3/ZI2IvC0iEV7H5G91OkGISDDwAjAWSASuFZFEb6PyVCFwn6omAoOAO+r45wFwN7DW6yCqiWeBz1S1K9CLOvy5iEgb4C4gSVV74CxpcI23UflfnU4QwAAgTVXTVTUfmA1M8Dgmz6jqDlVd7r4/hPMF0MbbqLwjIrHAOOAVr2Pxmog0As4HZgCoar6q5ngalPdCgHoiEgJEAts9jsfv6nqCaANk+GxnUoe/EH2JSDzOMq8/eByKl54B/hso9jiO6iAByAZedZvcXhGR+l4H5RV3SeT/BbYBO4ADqjrf26j8r64nCFMKEWkA/Au4R1UPeh2PF0TkEmC3qi7zOpZqIgToC7yoqn2AI0Cd7bMTkWic1oYEoDVQX0Su9zYq/6vrCSILaOuzHevuq7NEJBQnObypqv/2Oh4PDQbGi8gWnKbHC0XkDW9D8lQmkKmqx2uU7+MkjLpqJLBZVbNVtQD4N3CexzH5XV1PEEuBTiKSICJhOJ1MczyOyTPiLAw+A1irqk97HY+XVPUBVY1V1Xicfxdfqmqt+wvxTKnqTi
BDRLq4u0YAdXmd+G3AIBGJdP+/GUEt7LQP8ToAL6lqoYjcCXyOMwphpqqmeByWlwYDNwCrRWSlu+93qjrXu5BMNfIb4E33j6l04BaP4/GMqv4gIu8Dy3FG/62gFk67YVNtGGOMKVVdb2IyxhhTBksQxhhjSmUJwhhjTKksQRhjjCmVJQhjjDGlsgRhTBlEpEhEVvq8prn7F7kzAK8SkcXHnw0QkTARecadGXijiHzkzud0vLyWIjJbRDaJyDIRmSsinUUkXkTWlLj3wyJyv/t+kIj84MawVkQersKPwdRhdfo5CGPKcUxVe5dx7DpVTRaRKcBTwHjgf4AooIuqFonILcC/RWSge80HwGuqeg2AiPQCWnDqfGCleQ24WlVXuTMQdynnfGP8whKEMZXzNXCPiETiPDiWoKpFAKr6qohMAi4EFChQ1b8fv1BVV8GJiRFPpznOhHC4ZdflJ5hNFbIEYUzZ6vk8UQ7wuKq+U+KcS4HVQEdgWymTGyYD3d33p5v4r0OJe7XEmS0U4K/AehFZBHyGUwvJPdNfwpiKsgRhTNlO18T0pogcA7bgTEERXcl7bfK9l28/g6pOF5E3gYuAicC1wPBK3s+YclmCMKZirlPV5OMbIrIPiBORKHexpeP6AceXK72yojdT1U3AiyLyMpAtIk1VdW9FyzPmTNgoJmP8QFWP4HQmP+12JCMiN+KsNPal+wp3O7Vxj58jIkPLK1tExrkzhgJ0AoqAHP/+Bsb8nCUIY8pWr8Qw1yfKOf8BIBfYICIbgauAy9UFXA6MdIe5pgCPAzvPII4bcPogVgKv49Reiir6Sxlzpmw2V2OMMaWyGoQxxphSWYIwxhhTKksQxhhjSmUJwhhjTKksQRhjjCmVJQhjjDGlsgRhjDGmVP8f61+P9kzCqMsAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training ended\n" + ] + } + ], + "source": [ + "train(train_dataload,test_dataload)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Project_files/Bengali/bengali_corpus_cleaned.pkl b/Project_files/Bengali/bengali_corpus_cleaned.pkl new file mode 100644 index 0000000..a232ac1 Binary files /dev/null and b/Project_files/Bengali/bengali_corpus_cleaned.pkl differ diff --git a/Project_files/Bengali/epoch_300.pt b/Project_files/Bengali/epoch_300.pt new file mode 100644 index 0000000..89e7c18 Binary files /dev/null and b/Project_files/Bengali/epoch_300.pt differ diff --git a/Project_files/Bengali/index_word.pkl b/Project_files/Bengali/index_word.pkl new file mode 100644 index 0000000..87a3fc2 Binary files /dev/null and b/Project_files/Bengali/index_word.pkl differ diff --git a/Project_files/Bengali/vocab.pkl b/Project_files/Bengali/vocab.pkl new file mode 100644 index 0000000..601c78c Binary files /dev/null and b/Project_files/Bengali/vocab.pkl differ diff --git a/Project_files/Bengali/word_index.pkl b/Project_files/Bengali/word_index.pkl new file mode 100644 index 0000000..82fd475 Binary files /dev/null and b/Project_files/Bengali/word_index.pkl differ diff --git a/Project_files/Hindi/epoch_300.pt b/Project_files/Hindi/epoch_300.pt new file mode 100644 index 0000000..89909c1 Binary files /dev/null and b/Project_files/Hindi/epoch_300.pt differ diff --git a/Project_files/Hindi/hindi_corpus_cleaned.pkl b/Project_files/Hindi/hindi_corpus_cleaned.pkl new file mode 100644 index 0000000..b652dce Binary files /dev/null and b/Project_files/Hindi/hindi_corpus_cleaned.pkl differ diff --git a/Project_files/Hindi/index_word.pkl b/Project_files/Hindi/index_word.pkl new file mode 100644 index 0000000..484816e Binary files /dev/null and b/Project_files/Hindi/index_word.pkl differ diff --git a/Project_files/Hindi/vocab.pkl b/Project_files/Hindi/vocab.pkl new file mode 100644 index 0000000..771245d Binary files /dev/null and b/Project_files/Hindi/vocab.pkl differ diff --git a/Project_files/Hindi/word_index.pkl b/Project_files/Hindi/word_index.pkl new file mode 100644 index 0000000..3b369b8 Binary files /dev/null and b/Project_files/Hindi/word_index.pkl differ diff --git a/Project_files/hindi_model.py b/Project_files/hindi_model.py new file mode 100644 index 0000000..fe9cbac --- /dev/null +++ b/Project_files/hindi_model.py @@ -0,0 +1,240 @@ +import re +import string + +import os +import time +import torch +import pickle +import datetime +import numpy as np +import pandas as pd +from pathlib import Path +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torch.autograd import Variable + +from tqdm.auto import tqdm, trange +from matplotlib import pyplot as plt + +from preprocess import proc_all + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +if device == 'cuda': + print("=================================") + print("GPU found") + print("Using GPU at cuda:",torch.cuda.current_device()) + 
print("=================================") + print(" ") + +data = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\t') + +data_text = data['text'] +textd = data_text + +file = "~/NNTI-WS2021-NLP-Project/Project_files/hindi_corpus_cleaned.pkl" + +if Path(file).is_file(): + text = pd.read_pickle(file) + print("Loaded the clean corpus") +else: + text = proc_all(textd) + +text = text['text'] +text = text.apply(lambda x: x.split()) + +print(text) + +V = list(set(text.sum())) #List of unique words in the corpus +all_words = list(text.sum()) #All the words without removing duplicates + +print("Total number of unique words are: ",len(V)) + +#Dictionaries of words and their indexes +word_index = {word: i for i,word in enumerate(V)} +index_word = {i: word for i,word in enumerate(V)} + +def word_to_one_hot(word): + id = V.index(word) + onehot = [0.] * len(V) + onehot[id] = 1. + return torch.tensor(onehot) + +get_onehot = dict((word, word_to_one_hot(word)) for word in V) + +def sampling_prob(word): + if word in all_words: + count = all_words.count(word) + zw_i = count / len(all_words) + p_wi_keep = (np.sqrt(zw_i/0.001) + 1)*(0.001/zw_i) + else: + p_wi_keep = 0 + return p_wi_keep + +def get_target_context(sentence,window): + for i,word in enumerate(sentence): + target = word_index[sentence[i]] + + for j in range(i - window, i + window): + if j!=i and j <= len(sentence)-1 and j>=0: + if sampling_prob(sentence[j]) > thres: + context = word_index[sentence[j]] + yield target,context + +embedding_size = 300 +learning_rate = 0.05 +epochs = 300 +thres = np.random.random() + + +# Create model +class Word2Vec(nn.Module): + def __init__(self): + super().__init__() + self.v_len = len(V) + self.es = embedding_size + self.epochs = epochs + + self.w1 = nn.Linear(len(V),embedding_size,False) + self.w2 = nn.Linear(embedding_size,len(V)) + self.soft = nn.LogSoftmax(dim = 1) + + def forward(self, one_hot): + one_hot = self.w1(one_hot) + one_hot=self.w2(one_hot) + output=self.soft(one_hot) + return output.cuda() + + def softmax(self,input): + output = self.soft(input) + return output + +# Define optimizer and loss +model = Word2Vec().cuda() +#optimizer = optim.SGD(model.parameters(), lr=learning_rate,momentum=0.9,nesterov=True) +optimizer = optim.Adam(model.parameters(), lr=learning_rate) +criterion = nn.NLLLoss() + +print("=====================================") +print("The Word2Vec model: ") +print(model) +print("=====================================") + +'''Gets the corpus and creates the training data with the target and its context +and returns a dataframe containing them in terms of their indexes. 
+def get_training_data(corpus,window):
+    '''Walks the corpus and builds the training data: for every target/context
+    pair it stores the one-hot vector of the target and the index of the
+    context word, and returns them as a dataframe.'''
+    t,c = [],[]
+    for sentence in corpus:
+        data = get_target_context(sentence,window)
+        for i,j in data:
+            x = get_onehot[index_word[i]]
+            t.append(x)
+            c.append(j)
+    t_data = pd.DataFrame(list(zip(t,c)),columns=["target","context"])
+    return t_data
+
+def generateDirectoryName(batchsize,x=0):
+    path = "m{}_e{}_lr{}_bs{}_es{}".format(len(text),epochs,learning_rate,batchsize,embedding_size)
+    while True:
+        dir_name = (path + ('_' + str(x) if x != 0 else '')).strip()
+        if not os.path.exists(dir_name):
+            os.mkdir(dir_name)
+            print("Successfully created a directory at ",dir_name)
+            return dir_name
+        else:
+            x = x + 1
+
+def save_model(path,epoch):
+    torch.save(model.state_dict(),"{}/epoch_{}.pt".format(path,epoch))
+
+def save_vocab(path):
+    with open(path+'/vocab.pkl','wb') as f:
+        pickle.dump(V,f)
+    print("Dumped vocabulary successfully")
+
+def save_wordindex(path):
+    with open(path+'/word_index.pkl','wb') as f:
+        pickle.dump(word_index,f)
+
+    with open(path+'/index_word.pkl','wb') as f:
+        pickle.dump(index_word,f)
+    print("Dumped word-index dictionaries successfully")
+
+def save_loss(path,loss):
+    with open(path+'/optimizer_loss.pkl','wb') as f:
+        pickle.dump(loss,f)
+
+    print("Dumped the loss history successfully")
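+
+# Sketch of how a saved checkpoint could be restored later (illustrative only;
+# the directory name depends on the run, so the path below is hypothetical):
+#   state = torch.load("m100_e300_lr0.05_bs60_es300/epoch_300.pt")
+#   model.load_state_dict(state)
+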
+def train(traindata,batchsize):
+    losses = []
+    runs = trange(1,epochs+1)
+    tqdm.write("Training started")
+    for epoch in runs:
+        total_loss = []
+        for wt,wc in zip(DataLoader(traindata.target.values,batch_size=batchsize),
+                         DataLoader(traindata.context.values,batch_size=batchsize)):
+            wt = wt.cuda()
+            wc = wc.cuda()
+            optimizer.zero_grad()
+            output = model(wt)
+            loss = criterion(output,wc)
+            total_loss.append(loss.item())
+            loss.backward()
+            optimizer.step()
+
+        if epoch % 50 == 0:
+            start = time.time()
+            tqdm.write("===========================================================")
+            tqdm.write("Saving the model state")
+            save_model(path,epoch)
+            end = str(datetime.timedelta(seconds = time.time()-start))
+            tqdm.write("Model state saved. It was completed in {}".format(end))
+            tqdm.write("===========================================================")
+
+        time.sleep(0.1)
+        tqdm.write("At epoch {} the loss is ({})".format(epoch,round(np.mean(total_loss),3)))
+        losses.append(np.mean(total_loss))
+
+    plt.xlabel("Epochs")
+    plt.ylabel("LOSS")
+    save_loss(path,losses)
+    plt.plot(losses)
+    plt.savefig(path+'/Plot.png')
+
+# Set the hyperparameters (window_size is consumed by get_training_data below)
+batch_size = 60
+window_size = 2
+
+start2 = time.time()
+print("=================================================")
+print("Collecting training data")
+starte = time.time()
+data = get_training_data(text,window_size)
+ende = str(datetime.timedelta(seconds = time.time()-starte))
+print("It took {} to collect the data".format(ende))
+print("The training data has {} target-context pairs".format(len(data)))
+print("===================================================")
+
+print("Sampling Threshold: ",thres)
+
+path = generateDirectoryName(batch_size)
+train(data,batch_size)
+
+end2 = str(datetime.timedelta(seconds = time.time()-start2))
+print("Training finished.")
+print("It took {} to finish training the model".format(end2))
+
+print("===========================================================")
+print("Saving Vocabulary and Word-Index dictionaries")
+print("===========================================================")
+start1 = time.time()
+save_vocab(path)
+save_wordindex(path)
+end1 = str(datetime.timedelta(seconds = time.time()-start1))
+print("===========================================================")
+print("Saved Successfully in {}".format(end1))
+print("===========================================================")
diff --git a/Project_files/preprocess.py b/Project_files/preprocess.py
new file mode 100644
index 0000000..66dc9d4
--- /dev/null
+++ b/Project_files/preprocess.py
@@ -0,0 +1,101 @@
+import re
+import string
+
+import os
+import time
+import torch
+import pickle
+import datetime
+import numpy as np
+import pandas as pd
+
+def punctuations_remove(input):
+    output = "".join([x for x in input if x not in string.punctuation])
+    return output
+
+def numbers_remove(input):
+    output = re.sub(r"[0-9]+", "", input)
+    return output
+
+def usernames_remove(input):
+    output = re.sub(r"@\S+", "", input)
+    return output
+
+def hashtag_remove(input):
+    output = re.sub(r"#\S+", "", input)
+    return output
+
+def http_remove(input):
+    output = re.sub(r"http\S+", "", input)
+    return output
+
+def emojis_remove(input):
+    EMOJI_PATTERN = re.compile(
+        "["
+        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
+        "\U0001F300-\U0001F5FF"  # symbols & pictographs
+        "\U0001F600-\U0001F64F"  # emoticons
+        "\U0001F680-\U0001F6FF"  # transport & map symbols
+        "\U0001F700-\U0001F77F"  # alchemical symbols
+        "\U0001F780-\U0001F7FF"  # Geometric Shapes Extended
+        "\U0001F800-\U0001F8FF"  # Supplemental Arrows-C
+        "\U0001F900-\U0001F9FF"  # Supplemental Symbols and Pictographs
+        "\U0001FA00-\U0001FA6F"  # Chess Symbols
+        "\U0001FA70-\U0001FAFF"  # Symbols and Pictographs Extended-A
+        "\U00002702-\U000027B0"  # Dingbats
+        "\U000024C2-\U0001F251"
+        "]+"
+    )
+
+    output = EMOJI_PATTERN.sub(r'',input)
+    return output
+
+def extra_whitespaces(input):
+    #str.replace would treat '\s+' literally, so use a regex substitution
+    output = re.sub(r"\s+", " ", input)
+    return output
+
+def stopwords_remove(m):
+    hindi_stopwords = pd.read_csv('https://raw.githubusercontent.com/stopwords-iso/stopwords-hi/master/stopwords-hi.txt').stack().tolist()
+    english_stopwords = 
pd.read_csv('https://raw.githubusercontent.com/stopwords-iso/stopwords-en/master/stopwords-en.txt').stack().tolist()
+    stopwords = hindi_stopwords + english_stopwords
+
+    output = pd.Series(m).apply(lambda x: [item for item in x.split() if item not in stopwords])
+    return output
+
+def tolower(input):
+    output = input.lower()
+    return output
+
+def corpus_preprocess(corpus):
+    corpus = corpus.apply(lambda x: tolower(x))
+    corpus = corpus.apply(lambda x: emojis_remove(x))
+    corpus = corpus.apply(lambda x: http_remove(x))
+    corpus = corpus.apply(lambda x: hashtag_remove(x))
+    corpus = corpus.apply(lambda x: numbers_remove(x))
+    corpus = corpus.apply(lambda x: usernames_remove(x))
+    corpus = corpus.apply(lambda x: punctuations_remove(x))
+    #Collapse whitespace while the entries are still plain strings, then
+    #remove stopwords (which also tokenizes the text) as the final step
+    corpus = corpus.apply(lambda x: extra_whitespaces(x))
+    corpus = corpus.apply(lambda x: stopwords_remove(x))
+    return corpus
+
+
+def proc_all(text):
+    print("Started Preprocessing")
+    cleanstart = time.time()
+    text = corpus_preprocess(text)
+    cleanend = str(datetime.timedelta(seconds = time.time()-cleanstart))
+    print("Preprocessing ended!")
+    print("Pre-processing the text took {}".format(cleanend))
+    print("===========================================================")
+    print("-------")
+
+    c = []
+    for sent in text[0]:
+        a = " ".join(sent)
+        c.append(a)
+    d = pd.DataFrame(c,columns=["text"])
+    with open('hindi_corpus_cleaned.pkl','wb') as f:
+        pickle.dump(d,f)
+    return d
+
diff --git a/README.md b/README.md
index 47def7d..a5a01b5 100644
--- a/README.md
+++ b/README.md
@@ -1,65 +1,22 @@
 # NNTI Final Project (Sentiment Analysis & Transfer Learning)
 NNTI (WS-2021), Saarland University, Germany
-## Introduction
-This is a final project for the course **Neural Networks: Theory and Implementation (NNTI)**. This project will introduce you to Sentiment Classification and Analysis. *Sentiment analysis* (also known as *opinion mining* or *emotion AI*) refers to the use of natural language processing, text analysis, computational linguistics, and/or biometrics to systematically identify, extract, quantify, and study affective states and subjective information. *Transfer learning* is a machine learning research problem focusing on storing knowledge gained while solving one problem and applying it to a different but related problem. In this project, we want you to create a neural sentiment classifier completely from scratch. You first train it on one type of dataset and then apply it to another related but different dataset. You are expected to make use of concepts that you have learnt in the lecture. The project is divided into three tasks, the details of which you can find below.
-
-## Repository
-We have created this Github repository for this project. You will need to
-
-* fork the repository into your own Github account.
-* update your forked repository with your code and solutions.
-* submit the report and the link to your public repository for the available code.
-
-## Distribution of Points
-The points in this project are equally distributed among the three tasks. You will able to score a maximum of 10 points per task, resulting to a total of 30 points in the entire project. How the 10 points are allotted for each task, is mentioned in the guidelines.
-
-## Task 1: Word Embeddings
-Neural networks operate on numerical data and not on string or characters. In order to train a neural network or even any machine learning model on text data, we first need to convert the text data in some form of numerical representation before feeding it to the model. There are obviously multiple ways to do this, some of which you have come across during the course of this lecture, like the one-hot encoding method. However, traditional methods like one-hot encoding were eventually replaced by neural Word Embeddings like Word2Vec [[1, 2](#references)] and GloVe [[3](#references)]. A *word embedding* or *word vector* is a vector representation of an input word that captures the meaning of that word in a semantic vector space. You can find a video lecture from Stanford about Word2Vec here for better understanding. For this task, you are expected to create your own word embeddings from scratch. You are supposed to use the HASOC Hindi [[4](#references)] sentiment dataset and train a neural network to extract word embeddings for the data. The unfinished code for this task is already in place in the corresponding Jupyter notebook which you can find in the repository.
-* Follow the instructions in the notebook, complete the code, and run it
-* Save the trained word embeddings
-* Update your repository with the completed notebook
-## Task 2: Sentiment Classifier & Transfer Learning
-In this task you are expected to reproduce ***Subtask A*** from the HASOC paper [[4](#references)] using the Hindi word embeddings from Task 1. Then, you will apply your knowledge of transfer learning by using your model from Task 1 to train Bengali word embeddings and then use the trained classifier to predict hate speech on this Bengali data set. The data is already included in the repository.
-You are expected to read some related research work (for example, encoder-decoder architecture, attention mechanism, etc.) in neural sentiment analysis and then create an end-to-end neural network architecture for the task. After training, you should report the accuracy score of the model on test data. Follow the steps below:
-
-* **Binary neural sentiment classifier:** Implement a binary neural sentiment classifier for the Hindi section of the corpus. Use your word embeddings from Task 1 for that. Report the accuracy score.
-* **Preprocess the Bengali data:** Split off a part of the Bengali corpus such that it roughly equals the Hindi corpus in size and distribution of classes (hatespeech/non-hatespeech). Then, apply the preprocessing pipeline from Task 1 to the new data. You can deviate from the pipeline, but should justify your decision.
-* **Bengali word embeddings:** Use the model you created in Task 1 to create Bengali word embeddings.
-* **Apply** classifier to Bengali data, and report accuracy. Retrain your model with the Bengali data. Report the new accuracy and justify your findings.
-
-## Task 3: Challenge Task
-In this third and final task of this project, you are expected to
-
-* Read multiple resources about the state-of-the-art work related to sentiment classification and analysis
-* Try to come up with methodologies that would possibly improve your existing results
-* Improve on the 3 accuracy scores from Task 2
+## Introduction
+Neural networks can be applied to many NLP tasks, such as text classification. In this report, our goal is to use a neural network architecture to correctly predict hate speech on the given dataset. The report is broadly composed of three main parts. First, since neural networks only operate on numerical data and not on strings or characters, we need to convert our text data into some form of numerical representation before feeding it to the model. In contrast to traditional NLP approaches, which associate words with discrete representations such as one-hot encodings, we use word embeddings to represent words by dense, low-dimensional, real-valued vectors. We then use these word embeddings to build a binary neural sentiment classifier for the Hindi dataset with an LSTM model; the classifier assigns each text to one of two classes, Hate and Offensive (HOF) or Non-Hate and Offensive (NOT). We then apply this classifier to the Bengali dataset using transfer learning. Finally, we try to improve our accuracy results by changing the model architecture to a convolutional neural network (CNN).
-Note: The task here should be a change in the model architecture, data representation, different approach, or some other similar considerable change in your process pipeline. Please note that although you should consider fine-tuning the model hyperparameters manually, just doing that does not count as a change here.
-## General Guidelines
-* You are not allowed to use ready-made libraries like Hugging Face.
-* You are allowed to use convenience methods for data loading & preprocessing from packages like *scipy*, *numpy*, *pandas*, *sklearn*, *NLTK*. If you use other packages, provide a reference and justify it.
-* Plagiarism will be penalized and can eventually lead to disqualification from the project and the course. Most importantly, we will check for plagiarism within groups. If we see any clear indication of plagiarism among groups, both the groups will be awarded 0 for the whole project. Discussion with groups is allowed (only in terms of concepts but not directly with code).
-* Cite any resources that you found to be helpful.
-* You are expected to provide a separate notebook for each task. However, the notebook should only contain runtime code. Functions or classes you write should be in separate python scripts (.py) that are imported into the notebook of that task.
-* Your code should be sufficiently commented so that we can grade it easily. Not providing proper documentation can lead to your code not being graded.
-* Write a well-documented academic report. The report needs to be 4-8 pages long following the NIPS format. You can have a look at Latex versions or in other formats. We expect from you a .pdf file from you. The way how you divide it is up to you but we roughly expect to have introduction, methodology, results, and conclusion sections. Of course you will have to cite every source that you use.
-* The main focus of our grading will be your observations and analysis of the results. Even though you might obtain bad results make comments on what could have gone wrong.
+### Files
+Task 1 - Task1_Word_Embeddings.ipynb
+Task 2 - LSTM-classifier-Task-2.ipynb
+Task 3 - CNN-task-3.ipynb
-## Submission instructions
-* You are required to submit the final project as a team of two students.
-* You should submit a detailed report and implementation in a zip file. Link to the repository should also suffice.
-* Make sure to write the MS Teams username, matriculation number, and the name of each member of your team on your submission.
-* If you have any trouble with the submission, contact the tutors before the deadline.
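+
+### Running
+A sketch of the intended invocation to retrain the Hindi word2vec embeddings from scratch (assuming PyTorch, pandas, numpy, matplotlib and tqdm are installed, and a CUDA GPU is available, since the scripts call `.cuda()` directly):
+
+```
+python Project_files/hindi_model.py
+```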
+##### Authors +Nishant Gajjar - 2577584 +[s8nigajj@stud.uni-saarland.de](s8nigajj@stud.uni-saarland.de) -## Contact -In case you encounter any problems or have any doubts regarding this project, please feel free to contact the tutors in charge (*Julius Dietmar Steuer*, *Sourav Dutta*) on MS Teams. +Zhifei Li - 7010552 +[zhli00001@stud.uni-saarland.de](zhli00001@stud.uni-saarland.de) -## References -1. Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jeff Dean. Distributed representations of words and phrases and their compositionality. *Advances in neural information processing systems*, 26:3111–3119, 2013. -2. Tomas Mikolov, Kai Chen, Greg Corrado, and Jeffrey Dean. Efficient estimation of word representations in vector space. *arXiv preprint arXiv:1301.3781*, 2013. -3. Jeffrey Pennington, Richard Socher, and Christopher D Manning. Glove: Global vectors forword representation. *In Proceedings of the 2014 conference on empirical methods in natural language processing (EMNLP)*, pages 1532–1543, 2014. -4. Thomas Mandl, Sandip Modha, Prasenjit Majumder, Daksh Patel, Mohana Dave, Chintak Mandlia, and Aditya Patel. Overview of the hasoc track at fire 2019: Hate speech and offensive content identification in indo-european languages. *In Proceedings of the 11th Forum for Information Retrieval Evaluation*, pages 14–17, 2019. diff --git a/Task1_Word_Embeddings.ipynb b/Task1_Word_Embeddings.ipynb index 3cfd498..1d3be3a 100644 --- a/Task1_Word_Embeddings.ipynb +++ b/Task1_Word_Embeddings.ipynb @@ -1,388 +1,1394 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "colab": { - "name": "NNTI_final_project_task_1.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "_VZXi_KGi0UR" + }, + "source": [ + "# Task 1: Word Embeddings (10 points)\n", + "\n", + "This notebook will guide you through all steps necessary to train a word2vec model (Detailed description in the PDF)." + ] }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "_VZXi_KGi0UR" - }, - "source": [ - "# Task 1: Word Embeddings (10 points)\r\n", - "\r\n", - "This notebook will guide you through all steps necessary to train a word2vec model (Detailed description in the PDF)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "48t-II1vkuau" - }, - "source": [ - "## Imports\r\n", - "\r\n", - "This code block is reserved for your imports. \r\n", - "\r\n", - "You are free to use the following packages: \r\n", - "\r\n", - "(List of packages)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4kh6nh84-AOL" - }, - "source": [ - "# Imports" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NWmk3hVllEcU" - }, - "source": [ - "# 1.1 Get the data (0.5 points)\r\n", - "\r\n", - "The Hindi portion HASOC corpus from [github.io](https://hasocfire.github.io/hasoc/2019/dataset.html) is already available in the repo, at data/hindi_hatespeech.tsv . Load it into a data structure of your choice. Then, split off a small part of the corpus as a development set (~100 data points).\r\n", - "\r\n", - "If you are using Colab the first two lines will let you upload folders or files from your local file system." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "XtI7DJ-0-AOP" - }, - "source": [ - "#TODO: implement!\n", - "\n", - "#from google.colab import files\n", - "#uploaded = files.upload()\n", - "\n", - "data =" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D-mSJ8nUlupB" - }, - "source": [ - "## 1.2 Data preparation (0.5 + 0.5 points)\r\n", - "\r\n", - "* Prepare the data by removing everything that does not contain information. \r\n", - "User names (starting with '@') and punctuation symbols clearly do not convey information, but we also want to get rid of so-called [stopwords](https://en.wikipedia.org/wiki/Stop_word), i. e. words that have little to no semantic content (and, but, yes, the...). Hindi stopwords can be found [here](https://github.com/stopwords-iso/stopwords-hi/blob/master/stopwords-hi.txt) Then, standardize the spelling by lowercasing all words.\r\n", - "Do this for the development section of the corpus for now.\r\n", - "\r\n", - "* What about hashtags (starting with '#') and emojis? Should they be removed too? Justify your answer in the report, and explain how you accounted for this in your implementation." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "CHcNeyKi-AOQ" - }, - "source": [ - "#TODO: implement!" - ], - "execution_count": null, - "outputs": [] - }, + { + "cell_type": "markdown", + "metadata": { + "id": "48t-II1vkuau" + }, + "source": [ + "## Imports\n", + "\n", + "This code block is reserved for your imports. \n", + "\n", + "You are free to use the following packages: \n", + "\n", + "(List of packages)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "4kh6nh84-AOL" + }, + "outputs": [], + "source": [ + "import re,string\n", + "\n", + "import os\n", + "import time\n", + "import torch\n", + "import pickle\n", + "import datetime\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from pathlib import Path\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader\n", + "from torch.autograd import Variable\n", + "\n", + "from tqdm.auto import tqdm, trange\n", + "from matplotlib import pyplot as plt\n", + "\n", + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]='2,3,4,5'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "Je09nozLmmMm" - }, - "source": [ - "## 1.3 Build the vocabulary (0.5 + 0.5 points)\r\n", - "\r\n", - "The input to the first layer of word2vec is an one-hot encoding of the current word. The output od the model is then compared to a numeric class label of the words within the size of the skip-gram window. Now\r\n", - "\r\n", - "* Compile a list of all words in the development section of your corpus and save it in a variable ```V```." 
- ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "=================================\n", + "GPU found\n", + "Using GPU at cuda: 0\n", + "=================================\n", + " \n" + ] + } + ], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "if device == 'cuda':\n", + " print(\"=================================\")\n", + " print(\"GPU found\")\n", + " print(\"Using GPU at cuda:\",torch.cuda.current_device())\n", + " print(\"=================================\")\n", + " print(\" \")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NWmk3hVllEcU" + }, + "source": [ + "# 1.1 Get the data (0.5 points)\n", + "\n", + "The Hindi portion HASOC corpus from [github.io](https://hasocfire.github.io/hasoc/2019/dataset.html) is already available in the repo, at data/hindi_hatespeech.tsv . Load it into a data structure of your choice. Then, split off a small part of the corpus as a development set (~100 data points).\n", + "\n", + "If you are using Colab the first two lines will let you upload folders or files from your local file system." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "id": "XtI7DJ-0-AOP" + }, + "outputs": [], + "source": [ + "data = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\\t')\n", + "data_text = data['text']\n", + "textd = data_text[:100]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "VpoGmTKx-AOQ" - }, - "source": [ - "#TODO: implement!\n", - "V = " + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
text_idtexttask_1task_2task_3
0hasoc_hi_5556बांग्लादेश की शानदार वापसी, भारत को 314 रन पर ...NOTNONENONE
1hasoc_hi_5648सब रंडी नाच देखने मे व्यस्त जैसे ही कोई #शांती...HOFPRFNUNT
2hasoc_hi_164तुम जैसे हरामियों के लिए बस जूतों की कमी है शु...HOFPRFNTIN
3hasoc_hi_3530बीजेपी MLA आकाश विजयवर्गीय जेल से रिहा, जमानत ...NOTNONENONE
4hasoc_hi_5206चमकी बुखार: विधानसभा परिसर में आरजेडी का प्रदर...NOTNONENONE
5hasoc_hi_5121मुंबई में बारिश से लोगों को काफी समस्या हो रही हैNOTNONENONE
6hasoc_hi_7142Ahmed's dad:-- beta aaj teri mammy kyu nahi ba...NOTNONENONE
7hasoc_hi_43215 लाख मुसलमान उर्स में, अजमेर की दरगाह पर आते ...NOTNONENONE
8hasoc_hi_4674Do mahashaktiyan mili hain, charo taraf khusi ...NOTNONENONE
9hasoc_hi_1637Chants of 'Jai Sri Ram' as Owaisi takes oath: ...NOTNONENONE
\n", + "
" ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WiaVglVNoENY" - }, - "source": [ - "* Then, write a function ```word_to_one_hot``` that returns a one-hot encoding of an arbitrary word in the vocabulary. The size of the one-hot encoding should be ```len(v)```." + "text/plain": [ + " text_id text task_1 \\\n", + "0 hasoc_hi_5556 बांग्लादेश की शानदार वापसी, भारत को 314 रन पर ... NOT \n", + "1 hasoc_hi_5648 सब रंडी नाच देखने मे व्यस्त जैसे ही कोई #शांती... HOF \n", + "2 hasoc_hi_164 तुम जैसे हरामियों के लिए बस जूतों की कमी है शु... HOF \n", + "3 hasoc_hi_3530 बीजेपी MLA आकाश विजयवर्गीय जेल से रिहा, जमानत ... NOT \n", + "4 hasoc_hi_5206 चमकी बुखार: विधानसभा परिसर में आरजेडी का प्रदर... NOT \n", + "5 hasoc_hi_5121 मुंबई में बारिश से लोगों को काफी समस्या हो रही है NOT \n", + "6 hasoc_hi_7142 Ahmed's dad:-- beta aaj teri mammy kyu nahi ba... NOT \n", + "7 hasoc_hi_4321 5 लाख मुसलमान उर्स में, अजमेर की दरगाह पर आते ... NOT \n", + "8 hasoc_hi_4674 Do mahashaktiyan mili hain, charo taraf khusi ... NOT \n", + "9 hasoc_hi_1637 Chants of 'Jai Sri Ram' as Owaisi takes oath: ... NOT \n", + "\n", + " task_2 task_3 \n", + "0 NONE NONE \n", + "1 PRFN UNT \n", + "2 PRFN TIN \n", + "3 NONE NONE \n", + "4 NONE NONE \n", + "5 NONE NONE \n", + "6 NONE NONE \n", + "7 NONE NONE \n", + "8 NONE NONE \n", + "9 NONE NONE " ] - }, + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D-mSJ8nUlupB" + }, + "source": [ + "## 1.2 Data preparation (0.5 + 0.5 points)\n", + "\n", + "* Prepare the data by removing everything that does not contain information. \n", + "User names (starting with '@') and punctuation symbols clearly do not convey information, but we also want to get rid of so-called [stopwords](https://en.wikipedia.org/wiki/Stop_word), i. e. words that have little to no semantic content (and, but, yes, the...). Hindi stopwords can be found [here](https://github.com/stopwords-iso/stopwords-hi/blob/master/stopwords-hi.txt) Then, standardize the spelling by lowercasing all words.\n", + "Do this for the development section of the corpus for now.\n", + "\n", + "* What about hashtags (starting with '#') and emojis? Should they be removed too? Justify your answer in the report, and explain how you accounted for this in your implementation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "import string\n", + "import datetime  # used by the timing printouts in the cells below\n", + "\n", + "def punctuations_remove(input):\n", + "    output = \"\".join([x for x in input if x not in string.punctuation])\n", + "    return output\n", + "\n", + "def numbers_remove(input):\n", + "    output = re.sub(r\"[0-9]+\", \"\", input)\n", + "    return output\n", + "\n", + "def usernames_remove(input):\n", + "    output = re.sub(r\"@\\S+\", \"\", input)\n", + "    return output\n", + "\n", + "def hashtag_remove(input):\n", + "    output = re.sub(r\"#\\S+\", \"\", input)\n", + "    return output\n", + "\n", + "def http_remove(input):\n", + "    output = re.sub(r\"http\\S+\", \"\", input)\n", + "    return output\n", + "\n", + "def emojis_remove(input):\n", + "    EMOJI_PATTERN = re.compile(\n", + "        \"[\"\n", + "        \"\\U0001F1E0-\\U0001F1FF\"  # flags (iOS)\n", + "        \"\\U0001F300-\\U0001F5FF\"  # symbols & pictographs\n", + "        \"\\U0001F600-\\U0001F64F\"  # emoticons\n", + "        \"\\U0001F680-\\U0001F6FF\"  # transport & map symbols\n", + "        \"\\U0001F700-\\U0001F77F\"  # alchemical symbols\n", + "        \"\\U0001F780-\\U0001F7FF\"  # Geometric Shapes Extended\n", + "        \"\\U0001F800-\\U0001F8FF\"  # Supplemental Arrows-C\n", + "        \"\\U0001F900-\\U0001F9FF\"  # Supplemental Symbols and Pictographs\n", + "        \"\\U0001FA00-\\U0001FA6F\"  # Chess Symbols\n", + "        \"\\U0001FA70-\\U0001FAFF\"  # Symbols and Pictographs Extended-A\n", + "        \"\\U00002702-\\U000027B0\"  # Dingbats\n", + "        \"\\U000024C2-\\U0001F251\"\n", + "        \"]+\"\n", + "    )\n", + "\n", + "    output = EMOJI_PATTERN.sub(r'', input)\n", + "    return output\n", + "\n", + "def extra_whitespaces(input):\n", + "    # str.replace treats '\\s+' as a literal string; collapsing runs of\n", + "    # whitespace needs a regex substitution\n", + "    output = re.sub(r\"\\s+\", \" \", input)\n", + "    return output\n", + "\n", + "# Load the stopword lists once; header=None stops read_csv from eating the first\n", + "# word as a column name, and downloading the lists per row would be very slow\n", + "hindi_stopwords = pd.read_csv('https://raw.githubusercontent.com/stopwords-iso/stopwords-hi/master/stopwords-hi.txt', header=None)[0].tolist()\n", + "english_stopwords = pd.read_csv('https://raw.githubusercontent.com/stopwords-iso/stopwords-en/master/stopwords-en.txt', header=None)[0].tolist()\n", + "stopwords = set(hindi_stopwords + english_stopwords)\n", + "\n", + "def stopwords_remove(m):\n", + "    output = pd.Series(m).apply(lambda x: [item for item in x.split() if item not in stopwords])\n", + "    return output\n", + "\n", + "def tolower(input):\n", + "    output = input.lower()\n", + "    return output\n", + "\n", + "def corpus_preprocess(corpus):\n", + "    corpus = corpus.apply(lambda x: tolower(x))\n", + "    corpus = corpus.apply(lambda x: emojis_remove(x))\n", + "    corpus = corpus.apply(lambda x: http_remove(x))\n", + "    corpus = corpus.apply(lambda x: hashtag_remove(x))\n", + "    corpus = corpus.apply(lambda x: numbers_remove(x))\n", + "    corpus = corpus.apply(lambda x: usernames_remove(x))\n", + "    corpus = corpus.apply(lambda x: punctuations_remove(x))\n", + "    # collapse whitespace while the rows are still strings; stopwords_remove tokenizes them\n", + "    corpus = corpus.apply(lambda x: extra_whitespaces(x))\n", + "    corpus = corpus.apply(lambda x: stopwords_remove(x))\n", + "    return corpus" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "scrolled": false + }, + "outputs": [ { + "name": "stdout", + "output_type": "stream", + "text": [ + "Started Preprocessing\n", + "Preprocessing ended!\n", + "Pre-processing the text took 0:00:08.811462\n", + "===========================================================\n" + ] + } ], + "source": [ + "print(\"Started Preprocessing\")\n", + "cleanstart = time.time()\n", + 
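"# corpus_preprocess expects a pandas Series of raw tweet strings; textd is assumed to be\n", + "# the development slice of the corpus prepared in section 1.1 (defined earlier in the notebook)\n", + 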
"cleantext = corpus_preprocess(textd)\n", + "cleanend = str(datetime.timedelta(seconds = time.time()-cleanstart))\n", + "print(\"Preprocessing ended!\")\n", + "print(\"Pre-processing the text took {}\".format(cleanend))\n", + "print(\"===========================================================\")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "def tokens(text):\n", + " c = []\n", + " for sent in text[0]:\n", + " a = \" \".join(sent)\n", + " c.append(a)\n", + " df_text = pd.DataFrame(c,columns=[\"text\"])\n", + " return df_text" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "text = tokens(cleantext)\n", + "text = text['text'].apply(lambda x: x.split())" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "gKD8zBlxVclh" - }, - "source": [ - "## 1.4 Subsampling (0.5 points)\r\n", - "\r\n", - "The probability to keep a word in a context is given by:\r\n", - "\r\n", - "$P_{keep}(w_i) = \\Big(\\sqrt{\\frac{z(w_i)}{0.001}}+1\\Big) \\cdot \\frac{0.001}{z(w_i)}$\r\n", - "\r\n", - "Where $z(w_i)$ is the relative frequency of the word $w_i$ in the corpus. Now,\r\n", - "* Calculate word frequencies\r\n", - "* Define a function ```sampling_prob``` that takes a word (string) as input and returns the probabiliy to **keep** the word in a context." + "data": { + "text/plain": [ + "0 [बांग्लादेश, शानदार, वापसी, भारत, रन, रोका]\n", + "1 [सब, रंडी, नाच, देखने, व्यस्त, होगा, सब, शुरू,...\n", + "2 [तुम, हरामियों, बस, जूतों, कमी, शुक्र, तुम्हार...\n", + "3 [बीजेपी, mla, आकाश, विजयवर्गीय, जेल, रिहा, जमा...\n", + "4 [चमकी, बुखार, विधानसभा, परिसर, आरजेडी, प्रदर्श...\n", + " ... \n", + "95 [देश, पहली, बार, सरकार, प्रो, इंकम्बेंसी, जनाद...\n", + "96 [आदमी, आदमी, मैं, पानी, बारे, सोचता, थालिखने, ...\n", + "97 [मादरजात, सनी, तेरे, पास, टाइम, नही, तोतेरी, म...\n", + "98 [थोर, क्रांतिकारक, राणी, लक्ष्मीबाई, यांना, पु...\n", + "99 [मुस्लिम, लोगों, वोट, मांगने, वाली, पार्टियां,...\n", + "Name: text, Length: 100, dtype: object" ] - }, + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Je09nozLmmMm" + }, + "source": [ + "## 1.3 Build the vocabulary (0.5 + 0.5 points)\n", + "\n", + "The input to the first layer of word2vec is an one-hot encoding of the current word. The output od the model is then compared to a numeric class label of the words within the size of the skip-gram window. Now\n", + "\n", + "* Compile a list of all words in the development section of your corpus and save it in a variable ```V```." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "VpoGmTKx-AOQ" + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "Mj4sDOVMMr0b" - }, - "source": [ - "#TODO: implement!\r\n", - "def sampling_prob(word):\r\n", - " pass" - ], - "execution_count": null, - "outputs": [] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of unique words are: 1097\n" + ] + } + ], + "source": [ + "V = list(set(text.sum())) #List of unique words in the corpus\n", + "all_words = list(text.sum()) #All the words without removing duplicates\n", + "print(\"Total number of unique words are: \",len(V))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "kxV1P90zplxu" - }, - "source": [ - "# 1.5 Skip-Grams (1 point)\r\n", - "\r\n", - "Now that you have the vocabulary and one-hot encodings at hand, you can start to do the actual work. The skip gram model requires training data of the shape ```(current_word, context)```, with ```context``` being the words before and/or after ```current_word``` within ```window_size```. \r\n", - "\r\n", - "* Have closer look on the original paper. If you feel to understand how skip-gram works, implement a function ```get_target_context``` that takes a sentence as input and [yield](https://docs.python.org/3.9/reference/simple_stmts.html#the-yield-statement)s a ```(current_word, context)```.\r\n", - "\r\n", - "* Use your ```sampling_prob``` function to drop words from contexts as you sample them. " + "data": { + "text/plain": [ + "['स्पेशली',\n", + " 'दी।',\n", + " 'hai',\n", + " 'पहचान',\n", + " 'बताया',\n", + " 'khus',\n", + " 'साहिब',\n", + " 'जाती',\n", + " 'नया',\n", + " 'श्रद्धांजलि',\n", + " 'जनता',\n", + " 'कप',\n", + " 'सवाल',\n", + " 'तैसी',\n", + " 'wale',\n", + " 'हैँ',\n", + " 'भारतीय',\n", + " 'समर्थक',\n", + " 'शत्शत्',\n", + " 'तुमहृदय']" ] - }, - { - "cell_type": "code", - "metadata": { - "id": "r8CCTpVy-AOR" - }, - "source": [ - "#TODO: implement!\n", - "\n", - "def get_target_context(sentence):\n", - " pass" - ], - "execution_count": null, - "outputs": [] - }, + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V[:20]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "#Dictionaries of words and their indexes\n", + "word_index = {word: i for i,word in enumerate(V)}\n", + "index_word = {i: word for i,word in enumerate(V)}" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "scrolled": true + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "gfEFgtkmuDjL" - }, - "source": [ - "# 1.6 Hyperparameters (0.5 points)\r\n", - "\r\n", - "According to the word2vec paper, what would be a good choice for the following hyperparameters? \r\n", - "\r\n", - "* Embedding dimension\r\n", - "* Window size\r\n", - "\r\n", - "Initialize them in a dictionary or as independent variables in the code block below. 
" + "data": { + "text/plain": [ + "152" ] - }, - { - "cell_type": "code", - "metadata": { - "id": "d7xSKuFJcYoD" - }, - "source": [ - "# Set hyperparameters\n", - "window_size = \n", - "embedding_size = \n", - "\n", - "# More hyperparameters\n", - "learning_rate = 0.05\n", - "epochs = 100" - ], - "execution_count": null, - "outputs": [] - }, + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "word_index['मुखिया']" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "xiM2zq-YunPx" - }, - "source": [ - "# 1.7 Pytorch Module (0.5 + 0.5 + 0.5 points)\r\n", - "\r\n", - "Pytorch provides a wrapper for your fancy and super-complex models: [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). The code block below contains a skeleton for such a wrapper. Now,\r\n", - "\r\n", - "* Initialize the two weight matrices of word2vec as fields of the class.\r\n", - "\r\n", - "* Override the ```forward``` method of this class. It should take a one-hot encoding as input, perform the matrix multiplications, and finally apply a log softmax on the output layer.\r\n", - "\r\n", - "* Initialize the model and save its weights in a variable. The Pytorch documentation will tell you how to do that." + "data": { + "text/plain": [ + "'जाती'" ] - }, + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index_word[7]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WiaVglVNoENY" + }, + "source": [ + "* Then, write a function ```word_to_one_hot``` that returns a one-hot encoding of an arbitrary word in the vocabulary. The size of the one-hot encoding should be ```len(v)```." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "yqPNw6IT-AOQ" + }, + "outputs": [], + "source": [ + "def word_to_one_hot(word):\n", + " id = V.index(word)\n", + " onehot = [0.] * len(V)\n", + " onehot[id] = 1.\n", + " return torch.tensor(onehot)\n", + "\n", + "get_onehot = dict((word, word_to_one_hot(word)) for word in V)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "D9sGNytYhwxS", - "outputId": "41645b64-e4ed-4e6a-e10f-74cb39b92230" - }, - "source": [ - "# Create model \n", - "\n", - "class Word2Vec(Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - "\n", - "\n", - " def forward(self, one_hot):\n", - " pass" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Word2Vec(\n", - " (input): Linear(in_features=534, out_features=300, bias=False)\n", - " (output): Linear(in_features=300, out_features=534, bias=False)\n", - ")\n" - ], - "name": "stdout" - } + "data": { + "text/plain": [ + "tensor([0., 0., 0., ..., 0., 0., 0.])" ] - }, + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_onehot['मुखिया']" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gKD8zBlxVclh" + }, + "source": [ + "## 1.4 Subsampling (0.5 points)\n", + "\n", + "The probability to keep a word in a context is given by:\n", + "\n", + "$P_{keep}(w_i) = \\Big(\\sqrt{\\frac{z(w_i)}{0.001}}+1\\Big) \\cdot \\frac{0.001}{z(w_i)}$\n", + "\n", + "Where $z(w_i)$ is the relative frequency of the word $w_i$ in the corpus. 
Now,\n", + "* Calculate word frequencies\n", + "* Define a function ```sampling_prob``` that takes a word (string) as input and returns the probability to **keep** the word in a context." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "Mj4sDOVMMr0b" + }, + "outputs": [], + "source": [ + "from collections import Counter\n", + "\n", + "# Precompute absolute word frequencies once; calling all_words.count(word) inside\n", + "# sampling_prob would rescan the whole corpus on every lookup\n", + "word_counts = Counter(all_words)\n", + "\n", + "def sampling_prob(word):\n", + "    if word in word_counts:\n", + "        zw_i = word_counts[word] / len(all_words)  # relative frequency z(w_i)\n", + "        # can exceed 1 for rare words, in which case the word is always kept\n", + "        p_wi_keep = (np.sqrt(zw_i/0.001) + 1)*(0.001/zw_i)\n", + "    else:\n", + "        p_wi_keep = 0\n", + "    return p_wi_keep" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ { + "data": { + "text/plain": [ + "2.7641229712289954" ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } ], + "source": [ + "sampling_prob('मुखिया')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kxV1P90zplxu" + }, + "source": [ + "# 1.5 Skip-Grams (1 point)\n", + "\n", + "Now that you have the vocabulary and one-hot encodings at hand, you can start to do the actual work. The skip-gram model requires training data of the shape ```(current_word, context)```, with ```context``` being the words before and/or after ```current_word``` within ```window_size```. \n", + "\n", + "* Have a closer look at the original paper. Once you understand how skip-gram works, implement a function ```get_target_context``` that takes a sentence as input and [yield](https://docs.python.org/3.9/reference/simple_stmts.html#the-yield-statement)s a ```(current_word, context)```.\n", + "\n", + "* Use your ```sampling_prob``` function to drop words from contexts as you sample them. " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "r8CCTpVy-AOR" + }, + "outputs": [], + "source": [ + "def get_target_context(sentence,window):\n", + "    for i,word in enumerate(sentence):\n", + "        target = word_index[word]\n", + "\n", + "        # range(i - window, i + window) would drop the right edge of the window,\n", + "        # so go one step further to cover both sides symmetrically\n", + "        for j in range(i - window, i + window + 1):\n", + "            if j!=i and j <= len(sentence)-1 and j>=0:\n", + "                # draw a fresh threshold per candidate so each context word is\n", + "                # kept independently with probability sampling_prob(...)\n", + "                if sampling_prob(sentence[j]) > np.random.random():\n", + "                    context = word_index[sentence[j]]\n", + "                    yield target,context" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gfEFgtkmuDjL" + }, + "source": [ + "# 1.6 Hyperparameters (0.5 points)\n", + "\n", + "According to the word2vec paper, what would be a good choice for the following hyperparameters? \n", + "\n", + "* Embedding dimension\n", + "* Window size\n", + "\n", + "Initialize them in a dictionary or as independent variables in the code block below. 
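(For orientation: the original word2vec papers report good results with embedding sizes of a few hundred dimensions and context windows of roughly 5-10 words.)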
" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "d7xSKuFJcYoD" + }, + "outputs": [], + "source": [ + "# Set hyperparameters\n", + "window_size = 2\n", + "embedding_size = 300\n", + "\n", + "# More hyperparameters\n", + "learning_rate = 0.05\n", + "epochs = 300\n", + "batch_size = 60" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xiM2zq-YunPx" + }, + "source": [ + "# 1.7 Pytorch Module (0.5 + 0.5 + 0.5 points)\n", + "\n", + "Pytorch provides a wrapper for your fancy and super-complex models: [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). The code block below contains a skeleton for such a wrapper. Now,\n", + "\n", + "* Initialize the two weight matrices of word2vec as fields of the class.\n", + "\n", + "* Override the ```forward``` method of this class. It should take a one-hot encoding as input, perform the matrix multiplications, and finally apply a log softmax on the output layer.\n", + "\n", + "* Initialize the model and save its weights in a variable. The Pytorch documentation will tell you how to do that." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "D9sGNytYhwxS", + "outputId": "41645b64-e4ed-4e6a-e10f-74cb39b92230" + }, + "outputs": [], + "source": [ + "# Create model \n", + "class Word2Vec(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.v_len = len(V)\n", + " self.es = embedding_size\n", + " self.epochs = epochs\n", + " \n", + " self.w1 = nn.Linear(len(V),embedding_size,False)\n", + " self.w2 = nn.Linear(embedding_size,len(V))\n", + " self.soft = nn.LogSoftmax(dim = 1)\n", + "\n", + " def forward(self, one_hot):\n", + " one_hot = self.w1(one_hot)\n", + " one_hot=self.w2(one_hot)\n", + " output=self.soft(one_hot)\n", + " return output.cuda()\n", + "\n", + " def softmax(self,input): \n", + " output = self.soft(input)\n", + " return output" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XefIDMMHv5zJ" + }, + "source": [ + "# 1.8 Loss function and optimizer (0.5 points)\n", + "\n", + "Initialize variables with [optimizer](https://pytorch.org/docs/stable/optim.html#module-torch.optim) and loss function. You can take what is used in the word2vec paper, but you can use alternative optimizers/loss functions if you explain your choice in the report." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "id": "V9-Ino-e29w3" + }, + "outputs": [], + "source": [ + "# Define optimizer and loss\n", + "model = Word2Vec().cuda()\n", + "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", + "criterion = nn.NLLLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "V9-Ino-e29w3" - }, - "source": [ - "# Define optimizer and loss\n", - "optimizer = \n", - "criterion = " - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ckTfK78Ew8wI" - }, - "source": [ - "# 1.9 Training the model (3 points)\r\n", - "\r\n", - "As everything is prepared, implement a training loop that performs several passes of the data set through the model. 
You are free to do this as you please, but your code should:\r\n", - "\r\n", - "* Load the weights saved in 1.6 at the start of every execution of the code block\r\n", - "* Print the accumulated loss at least after every epoch (the accumulate loss should be reset after every epoch)\r\n", - "* Define a criterion for the training procedure to terminate if a certain loss value is reached. You can find the threshold by observing the loss for the development set.\r\n", - "\r\n", - "You can play around with the number of epochs and the learning rate." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "LbMGD5L0mLDx" - }, - "source": [ - "# Define train procedure\n", - "\n", - "# load initial weights\n", - "\n", - "def train():\n", - " \n", - " print(\"Training started\")\n", - "\n", - "train()\n", - "\n", - "print(\"Training finished\")" - ], - "execution_count": null, - "outputs": [] + "name": "stdout", + "output_type": "stream", + "text": [ + "=====================================\n", + "The Word2Vec model: \n", + "Word2Vec(\n", + "  (w1): Linear(in_features=1097, out_features=300, bias=False)\n", + "  (w2): Linear(in_features=300, out_features=1097, bias=True)\n", + "  (soft): LogSoftmax(dim=1)\n", + ")\n", + "=====================================\n" + ] + } + ], + "source": [ + "print(\"=====================================\")\n", + "print(\"The Word2Vec model: \")\n", + "print(model)\n", + "print(\"=====================================\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ckTfK78Ew8wI" + }, + "source": [ + "# 1.9 Training the model (3 points)\n", + "\n", + "Now that everything is prepared, implement a training loop that performs several passes of the data set through the model. You are free to do this as you please, but your code should:\n", + "\n", + "* Load the weights saved in 1.6 at the start of every execution of the code block\n", + "* Print the accumulated loss at least after every epoch (the accumulated loss should be reset after every epoch)\n", + "* Define a criterion for the training procedure to terminate if a certain loss value is reached. You can find the threshold by observing the loss for the development set.\n", + "\n", + "You can play around with the number of epochs and the learning rate." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "'''Builds the training data from the corpus: for every (target, context) pair\n", + "yielded by get_target_context, store the one-hot target vector and the context\n", + "word index in a dataframe.\n", + "'''\n", + "def get_training_data(corpus,window):\n", + "    t,c = [],[]\n", + "    for sentence in corpus:\n", + "        data = get_target_context(sentence,window)\n", + "        for i,j in data:\n", + "            x = get_onehot[index_word[i]]\n", + "            t.append(x)\n", + "            c.append(j)\n", + "    t_data = pd.DataFrame(list(zip(t,c)),columns=[\"target\",\"context\"])\n", + "    return t_data" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "id": "LbMGD5L0mLDx" + }, + "outputs": [], + "source": [ + "def train(traindata,batchsize):\n", + "    losses = []\n", + "    print(\"Training started\")\n", + "    for epoch in range(1,epochs+1):\n", + "        total_loss = []\n", + "        # one SGD step per mini-batch of (one-hot target, context index) pairs\n", + "        for wt,wc in zip(DataLoader(traindata.target.values,batch_size=batchsize),\n", + "                         DataLoader(traindata.context.values,batch_size=batchsize)):\n", + "            wt = wt.cuda()\n", + "            wc = wc.cuda()\n", + "            optimizer.zero_grad()\n", + "            output = model(wt)\n", + "            loss = criterion(output,wc)\n", + "            total_loss.append(loss.item())\n", + "            loss.backward()\n", + "            optimizer.step()\n", + "\n", + "        if epoch % 50 == 0:\n", + "            start = time.time()\n", + "            print(\"===========================================================\")\n", + "            print(\"Saving the model state\")\n", + "            save_model(epoch)\n", + "            end = str(datetime.timedelta(seconds = time.time()-start))\n", + "            print(\"Model state saved. 
It was completed in {}\".format(end)) \n", + " print(\"===========================================================\")\n", + "\n", + " if np.mean(total_loss) < 1.2:\n", + " break;\n", + " print(\"At epoch {} the loss is {}\".format(epoch ,round(np.mean(total_loss),3)))\n", + " losses.append(np.mean(total_loss))\n", + "\n", + " plt.xlabel(\"Epochs\")\n", + " plt.ylabel(\"LOSS\")\n", + " plt.plot(losses)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def save_model(epoch):\n", + " torch.save(model.state_dict(),\"epoch_{}.pt\".format(epoch))" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "scrolled": false + }, + "outputs": [ { - "cell_type": "code", - "metadata": { - "id": "LbMGD5L0mLDx" - }, - "source": [ - "# Define train procedure\n", - "\n", - "# load initial weights\n", - "\n", - "def train():\n", - " \n", - " print(\"Training started\")\n", - "\n", - "train()\n", - "\n", - "print(\"Training finished\")" - ], - "execution_count": null, - "outputs": [] + "name": "stdout", + "output_type": "stream", + "text": [ + "=================================================\n", + "Collecting training data\n", + "It took 0:00:00.318723 to collect the data\n", + "The training data has 4028 target-context pairs\n", + "===================================================\n", + "Training started\n", + "At epoch 1 the loss is 7.006\n", + "At epoch 2 the loss is 6.99\n", + "At epoch 3 the loss is 6.974\n", + "At epoch 4 the loss is 6.96\n", + "At epoch 5 the loss is 6.946\n", + "At epoch 6 the loss is 6.932\n", + "At epoch 7 the loss is 6.92\n", + "At epoch 8 the loss is 6.907\n", + "At epoch 9 the loss is 6.895\n", + "At epoch 10 the loss is 6.884\n", + "At epoch 11 the loss is 6.872\n", + "At epoch 12 the loss is 6.861\n", + "At epoch 13 the loss is 6.849\n", + "At epoch 14 the loss is 6.837\n", + "At epoch 15 the loss is 6.825\n", + "At epoch 16 the loss is 6.812\n", + "At epoch 17 the loss is 6.798\n", + "At epoch 18 the loss is 6.783\n", + "At epoch 19 the loss is 6.767\n", + "At epoch 20 the loss is 6.751\n", + "At epoch 21 the loss is 6.736\n", + "At epoch 22 the loss is 6.723\n", + "At epoch 23 the loss is 6.711\n", + "At epoch 24 the loss is 6.699\n", + "At epoch 25 the loss is 6.688\n", + "At epoch 26 the loss is 6.677\n", + "At epoch 27 the loss is 6.666\n", + "At epoch 28 the loss is 6.654\n", + "At epoch 29 the loss is 6.643\n", + "At epoch 30 the loss is 6.631\n", + "At epoch 31 the loss is 6.618\n", + "At epoch 32 the loss is 6.605\n", + "At epoch 33 the loss is 6.591\n", + "At epoch 34 the loss is 6.576\n", + "At epoch 35 the loss is 6.561\n", + "At epoch 36 the loss is 6.544\n", + "At epoch 37 the loss is 6.527\n", + "At epoch 38 the loss is 6.508\n", + "At epoch 39 the loss is 6.489\n", + "At epoch 40 the loss is 6.468\n", + "At epoch 41 the loss is 6.446\n", + "At epoch 42 the loss is 6.423\n", + "At epoch 43 the loss is 6.398\n", + "At epoch 44 the loss is 6.373\n", + "At epoch 45 the loss is 6.345\n", + "At epoch 46 the loss is 6.317\n", + "At epoch 47 the loss is 6.287\n", + "At epoch 48 the loss is 6.256\n", + "At epoch 49 the loss is 6.224\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.011237\n", + "===========================================================\n", + "At epoch 50 the loss is 6.19\n", + "At epoch 51 the loss is 6.155\n", + "At epoch 52 the loss is 6.118\n", + "At epoch 53 the loss is 6.08\n", + "At epoch 54 the loss is 6.04\n", + "At epoch 55 the loss is 5.999\n", + "At epoch 56 the loss is 5.956\n", + "At epoch 57 the loss is 5.912\n", + "At epoch 58 the loss is 5.867\n", + "At epoch 59 the loss is 5.82\n", + "At epoch 60 the loss is 5.772\n", + "At epoch 61 the loss is 5.722\n", + "At epoch 62 the loss is 5.671\n", + "At epoch 63 the loss is 5.619\n", + "At epoch 64 the loss is 5.565\n", + "At epoch 65 the loss is 5.51\n", + "At epoch 66 the loss is 5.454\n", + "At epoch 67 the loss is 5.397\n", + "At epoch 68 the loss is 5.339\n", + "At epoch 69 the loss is 5.28\n", + "At epoch 70 the loss is 5.219\n", + "At epoch 71 the loss is 5.159\n", + "At epoch 72 the loss is 5.097\n", + "At epoch 73 the loss is 5.035\n", + "At epoch 74 the loss is 4.972\n", + "At epoch 75 the loss is 4.909\n", + "At epoch 76 the loss is 4.845\n", + "At epoch 77 the loss is 4.781\n", + "At epoch 78 the loss is 4.717\n", + "At epoch 79 the loss is 4.652\n", + "At epoch 80 the loss is 4.587\n", + "At epoch 81 the loss is 4.522\n", + "At epoch 82 the loss is 4.458\n", + "At epoch 83 the loss is 4.393\n", + "At epoch 84 the loss is 4.328\n", + "At epoch 85 the loss is 4.263\n", + "At epoch 86 the loss is 4.198\n", + "At epoch 87 the loss is 4.134\n", + "At epoch 88 the loss is 4.07\n", + "At epoch 89 the loss is 4.005\n", + "At epoch 90 the loss is 3.942\n", + "At epoch 91 the loss is 3.878\n", + "At epoch 92 the loss is 3.815\n", + "At epoch 93 the loss is 3.752\n", + "At epoch 94 the loss is 3.69\n", + "At epoch 95 the loss is 3.628\n", + "At epoch 96 the loss is 3.567\n", + "At epoch 97 the loss is 3.506\n", + "At epoch 98 the loss is 3.446\n", + "At epoch 99 the loss is 3.386\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.006575\n", + "===========================================================\n", + "At epoch 100 the loss is 3.326\n", + "At epoch 101 the loss is 3.268\n", + "At epoch 102 the loss is 3.21\n", + "At epoch 103 the loss is 3.153\n", + "At epoch 104 the loss is 3.096\n", + "At epoch 105 the loss is 3.04\n", + "At epoch 106 the loss is 2.985\n", + "At epoch 107 the loss is 2.931\n", + "At epoch 108 the loss is 2.877\n", + "At epoch 109 the loss is 2.825\n", + "At epoch 110 the loss is 2.773\n", + "At epoch 111 the loss is 2.722\n", + "At epoch 112 the loss is 2.673\n", + "At epoch 113 the loss is 2.624\n", + "At epoch 114 the loss is 2.576\n", + "At epoch 115 the loss is 2.53\n", + "At epoch 116 the loss is 2.484\n", + "At epoch 117 the loss is 2.44\n", + "At epoch 118 the loss is 2.398\n", + "At epoch 119 the loss is 2.356\n", + "At epoch 120 the loss is 2.316\n", + "At epoch 121 the loss is 2.278\n", + "At epoch 122 the loss is 2.241\n", + "At epoch 123 the loss is 2.206\n", + "At epoch 124 the loss is 2.172\n", + "At epoch 125 the loss is 2.14\n", + "At epoch 126 the loss is 2.109\n", + "At epoch 127 the loss is 2.081\n", + "At epoch 128 the loss is 2.053\n", + "At epoch 129 the loss is 2.028\n", + "At epoch 130 the loss is 2.003\n", + "At epoch 131 the loss is 1.98\n", + "At epoch 132 the loss is 1.959\n", + "At epoch 133 the loss is 1.938\n", + "At epoch 134 the loss is 1.919\n", + "At epoch 135 the loss is 1.9\n", + "At epoch 136 the loss is 1.883\n", + "At epoch 137 the loss is 1.867\n", + "At epoch 138 the loss is 1.851\n", + "At epoch 139 the loss is 1.836\n", + "At epoch 140 the loss is 1.822\n", + "At epoch 141 the loss is 1.809\n", + "At epoch 142 the loss is 1.797\n", + "At epoch 143 the loss is 1.785\n", + "At epoch 144 the loss is 1.773\n", + "At epoch 145 the loss is 1.763\n", + "At epoch 146 the loss is 1.752\n", + "At epoch 147 the loss is 1.742\n", + "At epoch 148 the loss is 1.733\n", + "At epoch 149 the loss is 1.724\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.010483\n", + "===========================================================\n", + "At epoch 150 the loss is 1.716\n", + "At epoch 151 the loss is 1.708\n", + "At epoch 152 the loss is 1.7\n", + "At epoch 153 the loss is 1.693\n", + "At epoch 154 the loss is 1.686\n", + "At epoch 155 the loss is 1.679\n", + "At epoch 156 the loss is 1.673\n", + "At epoch 157 the loss is 1.667\n", + "At epoch 158 the loss is 1.661\n", + "At epoch 159 the loss is 1.656\n", + "At epoch 160 the loss is 1.65\n", + "At epoch 161 the loss is 1.645\n", + "At epoch 162 the loss is 1.64\n", + "At epoch 163 the loss is 1.636\n", + "At epoch 164 the loss is 1.631\n", + "At epoch 165 the loss is 1.627\n", + "At epoch 166 the loss is 1.623\n", + "At epoch 167 the loss is 1.619\n", + "At epoch 168 the loss is 1.616\n", + "At epoch 169 the loss is 1.612\n", + "At epoch 170 the loss is 1.609\n", + "At epoch 171 the loss is 1.605\n", + "At epoch 172 the loss is 1.602\n", + "At epoch 173 the loss is 1.599\n", + "At epoch 174 the loss is 1.596\n", + "At epoch 175 the loss is 1.593\n", + "At epoch 176 the loss is 1.591\n", + "At epoch 177 the loss is 1.588\n", + "At epoch 178 the loss is 1.586\n", + "At epoch 179 the loss is 1.583\n", + "At epoch 180 the loss is 1.581\n", + "At epoch 181 the loss is 1.579\n", + "At epoch 182 the loss is 1.577\n", + "At epoch 183 the loss is 1.574\n", + "At epoch 184 the loss is 1.572\n", + "At epoch 185 the loss is 1.571\n", + "At epoch 186 the loss is 1.569\n", + "At epoch 187 the loss is 1.567\n", + "At epoch 188 the loss is 1.565\n", + "At epoch 189 the loss is 1.564\n", + "At epoch 190 the loss is 1.562\n", + "At epoch 191 the loss is 1.56\n", + "At epoch 192 the loss is 1.559\n", + "At epoch 193 the loss is 1.558\n", + "At epoch 194 the loss is 1.556\n", + "At epoch 195 the loss is 1.555\n", + "At epoch 196 the loss is 1.553\n", + "At epoch 197 the loss is 1.552\n", + "At epoch 198 the loss is 1.551\n", + "At epoch 199 the loss is 1.55\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.007587\n", + "===========================================================\n", + "At epoch 200 the loss is 1.549\n", + "At epoch 201 the loss is 1.548\n", + "At epoch 202 the loss is 1.547\n", + "At epoch 203 the loss is 1.546\n", + "At epoch 204 the loss is 1.545\n", + "At epoch 205 the loss is 1.544\n", + "At epoch 206 the loss is 1.543\n", + "At epoch 207 the loss is 1.542\n", + "At epoch 208 the loss is 1.541\n", + "At epoch 209 the loss is 1.54\n", + "At epoch 210 the loss is 1.539\n", + "At epoch 211 the loss is 1.538\n", + "At epoch 212 the loss is 1.537\n", + "At epoch 213 the loss is 1.537\n", + "At epoch 214 the loss is 1.536\n", + "At epoch 215 the loss is 1.535\n", + "At epoch 216 the loss is 1.534\n", + "At epoch 217 the loss is 1.534\n", + "At epoch 218 the loss is 1.533\n", + "At epoch 219 the loss is 1.532\n", + "At epoch 220 the loss is 1.532\n", + "At epoch 221 the loss is 1.531\n", + "At epoch 222 the loss is 1.53\n", + "At epoch 223 the loss is 1.53\n", + "At epoch 224 the loss is 1.529\n", + "At epoch 225 the loss is 1.529\n", + "At epoch 226 the loss is 1.528\n", + "At epoch 227 the loss is 1.528\n", + "At epoch 228 the loss is 1.527\n", + "At epoch 229 the loss is 1.527\n", + "At epoch 230 the loss is 1.526\n", + "At epoch 231 the loss is 1.525\n", + "At epoch 232 the loss is 1.525\n", + "At epoch 233 the loss is 1.524\n", + "At epoch 234 the loss is 1.524\n", + "At epoch 235 the loss is 1.524\n", + "At epoch 236 the loss is 1.523\n", + "At epoch 237 the loss is 1.523\n", + "At epoch 238 the loss is 1.522\n", + "At epoch 239 the loss is 1.522\n", + "At epoch 240 the loss is 1.521\n", + "At epoch 241 the loss is 1.521\n", + "At epoch 242 the loss is 1.521\n", + "At epoch 243 the loss is 1.52\n", + "At epoch 244 the loss is 1.52\n", + "At epoch 245 the loss is 1.519\n", + "At epoch 246 the loss is 1.519\n", + "At epoch 247 the loss is 1.519\n", + "At epoch 248 the loss is 1.518\n", + "At epoch 249 the loss is 1.518\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. 
It was completed in 0:00:00.005157\n", + "===========================================================\n", + "At epoch 250 the loss is 1.518\n", + "At epoch 251 the loss is 1.517\n", + "At epoch 252 the loss is 1.517\n", + "At epoch 253 the loss is 1.517\n", + "At epoch 254 the loss is 1.516\n", + "At epoch 255 the loss is 1.516\n", + "At epoch 256 the loss is 1.516\n", + "At epoch 257 the loss is 1.515\n", + "At epoch 258 the loss is 1.515\n", + "At epoch 259 the loss is 1.515\n", + "At epoch 260 the loss is 1.514\n", + "At epoch 261 the loss is 1.514\n", + "At epoch 262 the loss is 1.514\n", + "At epoch 263 the loss is 1.514\n", + "At epoch 264 the loss is 1.513\n", + "At epoch 265 the loss is 1.513\n", + "At epoch 266 the loss is 1.513\n", + "At epoch 267 the loss is 1.512\n", + "At epoch 268 the loss is 1.512\n", + "At epoch 269 the loss is 1.512\n", + "At epoch 270 the loss is 1.512\n", + "At epoch 271 the loss is 1.511\n", + "At epoch 272 the loss is 1.511\n", + "At epoch 273 the loss is 1.511\n", + "At epoch 274 the loss is 1.511\n", + "At epoch 275 the loss is 1.51\n", + "At epoch 276 the loss is 1.51\n", + "At epoch 277 the loss is 1.51\n", + "At epoch 278 the loss is 1.51\n", + "At epoch 279 the loss is 1.51\n", + "At epoch 280 the loss is 1.509\n", + "At epoch 281 the loss is 1.509\n", + "At epoch 282 the loss is 1.509\n", + "At epoch 283 the loss is 1.509\n", + "At epoch 284 the loss is 1.509\n", + "At epoch 285 the loss is 1.508\n", + "At epoch 286 the loss is 1.508\n", + "At epoch 287 the loss is 1.508\n", + "At epoch 288 the loss is 1.508\n", + "At epoch 289 the loss is 1.508\n", + "At epoch 290 the loss is 1.507\n", + "At epoch 291 the loss is 1.507\n", + "At epoch 292 the loss is 1.507\n", + "At epoch 293 the loss is 1.507\n", + "At epoch 294 the loss is 1.507\n", + "At epoch 295 the loss is 1.506\n", + "At epoch 296 the loss is 1.506\n", + "At epoch 297 the loss is 1.506\n", + "At epoch 298 the loss is 1.506\n", + "At epoch 299 the loss is 1.506\n", + "===========================================================\n", + "Saving the model state\n", + "Model state saved. It was completed in 0:00:00.009836\n", + "===========================================================\n", + "At epoch 300 the loss is 1.506\n", + "Training finished.\n", + "It took 0:00:40.983541 to finish training the model\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "BgQkaYstyj0Q" - }, - "source": [ - "# 1.10 Train on the full dataset (0.5 points)\r\n", - "\r\n", - "Now, go back to 1.1 and remove the restriction on the number of sentences in your corpus. Then, reexecute code blocks 1.2, 1.3 and 1.6 (or those relevant if you created additional ones). \r\n", - "\r\n", - "* Then, retrain your model on the complete dataset.\r\n", - "\r\n", - "* Now, the input weights of the model contain the desired word embeddings! Save them together with the corresponding vocabulary items (Pytorch provides a nice [functionality](https://pytorch.org/tutorials/beginner/saving_loading_models.html) for this)." 
+ "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEGCAYAAABvtY4XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgqUlEQVR4nO3dd3gd5Zn+8e9zjqoly1UW7pKxjbHBGCMMGENiO6YlWVJIISQbAgkhMQR+7CYhu79NsrvZluxmCSlwOZhO6HUTwtIJxWDLBRcM7hUXuUuWrfrsH2dkhJEsy9JoTrk/13WuM2dmdOYZj3yf0XvmfcfcHRERST+xqAsQEZFwKOBFRNKUAl5EJE0p4EVE0pQCXkQkTWVFXUBL/fv399LS0qjLEBFJGfPnz9/h7sWtLUuqgC8tLaWioiLqMkREUoaZrW9rmZpoRETSlAJeRCRNKeBFRNKUAl5EJE0p4EVE0lRoAW9mJ5jZohaPfWZ2fVjbExGRDwvtMkl3fw+YAGBmcWAz8HhY2xMRkQ/rriaa6cBqd2/zes3OuPmFlVSs2xXGW4uIpKzuCvgvA/e3tsDMrjKzCjOrqKys7PAb7z1Qz31vreeSW+fw7XsqWF1Z3dlaRUTSgoV9ww8zywHeB8a5+7YjrVteXu7H0pO1pq6B2a+u5dZXVnOwoYnPTxzMVecez8gBhcdYtYhIajCz+e5e3tqy7hiq4EJgQXvh3hk9crK4dvooLj1jGL95cRX3z93AQxWbmDG2hKs/djynDe8T1qZFRJJWd5zBPwD8r7vf0d66x3oGf7id1bXcNWc9d89Zx56aek4b3ofLJ5dywUnHkR3XlaEikj6OdAYfasCbWQGwARjh7nvbW7+rAr7Z/toGHqrYyJ1vrGP9zhpKinL52pnDuXTSMPoV5nbZdkREohJZwHdUVwd8s6Ym5+UV27nj9XW8unIHOVkx/uqUQVw+uZSTBvfq8u2JiHSXqNvgIxeLGdPGlDBtTAmrtldx1xvreXTBJh6Zv4nTS/tw+eQyzh9XQpaab0QkjWTEGXxr9h6o5+GKjdw1Zx0bdx1gSJ98bpgxmosnDCYes26pQUSkszK+ieZIGpucF5Zv4+YXV7J08z5GlxTyt+edwIyxJZgp6EUkuR0p4DO+TSIeM84bdxxPzZzCb78ykYZG56p75vO5W97g1ZWVJNMHoIhIR2T8GfzhGhqbeGT+Jn71wkq27D3IiOICLjtjOJdMHEKvHtmR1iYicjg10RyDg/WNPL1kC/e+uZ4FG/aQG1x587WzhjN+SO+oyxMRARTwnbbs/b3c99YGnli4mZq6Rk4Z0ouvnjmcT58yiLzseNTliUgGU8B3kaqD9Ty+cDN3z1nPqu3V9O6RzZfKh3LllDIGFOVFXZ6IZCAFfBdzd95cs4t731zPM8u2Eo8ZX5k0jG9/bAQDe+VHXZ6IZJCM7+jU1cyMs47vx1nH92P9zv387qXV3Pvmev4wdwPfOqeM7358JAW5+qcVkWjpDL6LbNxVwy+fW8HjCzdTUpTLjy48kYsnDNK19CISKl0H3w2G9u3Bf39pAo9+ZzIlRXlc/+AivnlXBZVVtVGXJiIZSgHfxU4b3ocnvns2P/7UWF5btYPzb/oLzyzdEnVZIpKBFPAhiMWMK6aU8cdrpzCodx5X37uAf/7jOzQ0NkVdmohkEAV8iEaV9OSx75zN5ZNLmf3aWr5x5zz21tRHXZaIZAgFfMhysmL89K/G8e+fO5k31+zkM797nfU790ddlohkAAV8N/nypGH84Vtnsqemji/cOoeV26qiLklE0pwCvhudXtqXB799Fg58adabLN3c7l0MRUSOmQK+m40u6clD3z6LvKwYl/7+TRZv2hN1SSKSphTwESjrX8BDV59Fr/xsLr9jHmsqq6MuSUTSkAI+IkP69ODuKyZhwNdmz2XbvoNRlyQiaUYBH6ERxYXc+Y1J7Kmp4+u3z2XvAV1CKSJdRwEfsZOH9GLWX5ezurKa792/kMam5BkbSERSmwI+CZw9sj//dPFJvLKikv989r2oyxGRNKExbZPEpZOGsWTzXm55eTXjBhXxqfGDoi5JRFKczuCTyE8/PY7Thvfh+w8v5t2t+6IuR0RSnAI+ieRkxbjlsokU5mVx7R8WcrC+MeqSRCSFhRrwZtbbzB4xs3fNbLmZnRXm9tLBgKI8fvnFU1i5vZp/fXp51OWISAoL+wz+V8Az7j4GOAVQYh2Fc0YV880pZdw9Zz0vLN8WdTkikqJCC3gz6wWcC8wGcPc6d98T1vbSzfcvOIGxA4v4/iOL2V6lTlAi0nFhnsGXAZXAHWa20MxuM7OCw1cys6vMrMLMKiorK0MsJ7XkZsW5+dIJ1NQ18HePLSWZ7p0rIqkhzIDPAiYCt7j7qcB+4MbDV3L3We5e7u7lxcXFIZaTekYO6MkNM0bz/PJtPL1ka9TliEiKCTPgNwGb3P2t4PUjJAJfOuCKs8s4eXAvfvLUUvbU1EVdjoikkNAC3t23AhvN7IRg1nTgnbC2l66y4jH+4/Pj2V1Tz8/+pO+oReTohX0VzbXAfWa2GJgA/GvI20tLYwcV8e1zR/DI/E28ulLfU4jI0Qk14N19UdC+Pt7dP+Puu8PcXjr73vRRlPUv4CdPLqOuoSnqckQkBagna4rIy47z40+PZc2O/dz5xtqoyxGRFKCATyFTTxjA9DED+NXzK9muG4SISDsU8CnmHz41lvpG59+feTfqUkQkySngU0xp/wKuPKeMxxZsZv76XVGXIyJJTAGfgq6ZOpKSolz++Y/L1cNVRNqkgE9BBblZ3DBjNIs27uGZperhKiKtU8CnqM9PHMKoAYX8/H/fo75Rl02KyEcp4FNUVjzGDy4Yw9od+3lw3saoyxGRJKSAT2GfOHEAp5f24abnV7K/tiHqckQkySjgU5iZceOFY9hRXcvs19T5SUQ+TAGf4k4b3pcZY0v4/atr2FtTH3U5IpJEFPBp4P99YjRVBxuY/dqaqEsRkSSigE8DYwcVceFJx3H76+s0ZryIHKKATxPXfWIU1bUN3Paq2uJFJEEBnybGHFfEJ8cP5I7X17Jrv87iRUQBn1aunz6KmvpGfv+q2uJFRAGfVkaV9OSTJw/knjnrdUWNiCjg0813Pz6S6toG7p6zLupSRCRiCvg0M3ZQEdPGDOD219dSU6ferSKZTAGfhr778ePZXVPPA3M1Ro1IJlPAp6Hy0r5MKuvLrL+s0Q26RTKYAj5NzZw6kq37DvL4wk1RlyIiEVHAp6lzR/Vn3KAibn1lDY1NuuuTSCZSwKcpM2Pm1JGs3bGfPy/dEnU5IhIBBXwaO3/ccYwoLuB3L63WvVtFMpACPo3FY8ZV54zgnS37mLN6Z9TliEg3U8Cnuc+cOpj+hTnM0vAFIhkn1IA3s3VmtsTMFplZRZjbktblZcf567NKefm9SlZsq4q6HBHpRt1xBj/V3Se4e3k3bEta8dUzh5OXHeM2ncWLZBQ10WSAvgU5XHLaEJ5Y+D7bqw5GXY6IdJOwA96
BZ81svpld1doKZnaVmVWYWUVlZWXI5WSuK6eMoL6pibvfWB91KSLSTcIO+CnuPhG4EJhpZucevoK7z3L3cncvLy4uDrmczFXWv4DzxpZwz5vrNQiZSIYINeDdfXPwvB14HJgU5vbkyL51zgj2Hqjn4QoNXyCSCUILeDMrMLOezdPAecDSsLYn7TtteB9OHdab2a+t1fAFIhkgzDP4EuA1M3sbmAv8yd2fCXF70g6zRMenDbtqeHbZ1qjLEZGQZYX1xu6+BjglrPeXY3PeuOMY1rcHs15dw4UnD4y6HBEJkS6TzDDxmHHF2aUs3LCH+et3R12OiIRIAZ+BvlA+lJ55Wdz+2tqoSxGRECngM1BBbhZfOWMYf166hY27aqIuR0RCooDPUJdPLiVmxl1vrIu6FBEJiQI+Qw3slc9FJw/kgXkbqTpYH3U5IhICBXwG++Y5ZVTXNvDgvI1RlyIiIVDAZ7DxQ3ozqbQvd7y+jobGpqjLEZEupoDPcFeeU8bmPQd49p1tUZciIl1MAZ/hPnFiCcP69tBY8SJpSAGf4Zo7Pi3YsIcFG9TxSSSdKODlUMen2er4JJJWFPCS6Pg0aRh/XqKOTyLp5IgBb2afNrPhLV7/2MzeNrOnzKws/PKku3x9cimmjk8iaaW9M/h/ASoBzOxTwFeBK4CngFvDLU2606De+XxSHZ9E0kp7Ae/u3vw3++eA2e4+391vA3R/vTTT3PHpId3xSSQttBfwZmaFZhYDpgMvtFiWF15ZEoXxQ3pzemkf7nh9rTo+iaSB9gL+JmARUAEsd/cKADM7FdgSamUSiSunjGDTbnV8EkkHRwx4d78d+BhwJXBRi0VbgG+EWJdEZMbYRMcnXTIpkvrau4pmOFDt7gvdvcnMpprZr4CvALqpZxpq7vg0f/1udXwSSXHtNdE8BBQAmNkE4GFgA4l7rf4u1MokMur4JJIe2gv4fHd/P5j+KnC7u/8XieaZSaFWJpFp2fFp0251fBJJVe1eRdNiehrBVTTurkss0pw6PomkvvYC/kUzeyhod+8DvAhgZgOBurCLk+gc6vg0Vx2fRFJVewF/PfAYsA6Y4u7N/9OPA/4+vLIkGVw5pYwqdXwSSVntXSbp7v4A8ARwqpl9ysxGBFfV/G+3VCiROWXoBx2fGps86nJEpIPau0yyyMweAp4nMQbNFcDzZvawmRV1R4ESrUMdn5bpqliRVNNeE83NwDvAKHf/nLt/DjgeWAL8JuziJHrNHZ9u0yWTIimnvYA/291/2vKqmaDZ5p+As45mA2YWN7OFZvbHzhQq0YjHjG8EHZ8WquOTSErpzA0/rP1VALgOWN6J7UjE1PFJJDW1F/BvBDf5+FCYm9k/AHPae3MzGwJ8Erjt2EuUqBU2d3xaulUdn0RSSHsBfy1wMrDKzB4NHqtJDFVwzVG8/03AD4A2O0aZ2VVmVmFmFZWVlUdZtnS3r08uBVDHJ5EU0t5lkvvc/QvAecCdweM8d7+EdkaTDO4Atd3d57ezjVnuXu7u5cXFuodIshrUO5+Lgo5P1bUNUZcjIkfhqNrg3X21u/9P8FgdzL6hnR87G/grM1sHPABMM7N7j71Uidqhjk/zNkZdiogchdC+ZHX3H7n7EHcvBb4MvOjuX+3E9iRiE4KOT7er45NISuhMwOt/eAa6ckqZOj6JpIj2erJWmdm+Vh5VwKCj3Yi7v+zun+p0tRK5GWOPY2jffF0yKZIC2vuStae7F7Xy6OnuWd1VpCSPxB2fyqhQxyeRpNeZJhrJUOr4JJIaFPDSYYW5WVwadHzavOdA1OWISBsU8HJM1PFJJPkp4OWYDA46Pt3/1gZ1fBJJUgp4OWbq+CSS3BTwcswmDO1N+fA+3PGGOj6JJCMFvHTKN88pY+OuAzz3jjo+iSQbBbx0SnPHp9te1SWTIslGAS+dEo8Z35ic6Pi0aOOeqMsRkRYU8NJpXzx9KD1z1fFJJNko4KXTCnOzuPSMYTy9ZIs6PokkEQW8dInLJ5diwO//sibqUkQkoICXLjGodz6fPXUwD8zbwI7q2qjLEREU8NKFvvPx46ltaOJ2tcWLJAUFvHSZEcWFXHTyQO6Zs569B+qjLkck4yngpUvN/PhIqmobuGfOuqhLEcl4CnjpUmMHFTFtzABmv7aWmjoNQiYSJQW8dLmZU0eyu6ae++dqEDKRKCngpcudNrwPZ47oy6y/rKa2oTHqckQylgJeQnHN1FFs21fLYws2R12KSMZSwEsozh7Zj1OG9OKWl1fT0NgUdTkiGUkBL6EwM2ZOHcmGXTX8z+L3oy5HJCMp4CU0nzixhBMHFnHzC6t0Fi8SAQW8hCYWM66bPoq1O/bz5CKdxYt0NwW8hOr8cSWMHVjEr19cqbN4kW6mgJdQmRnXf2IU63bW8PhCXVEj0p1CC3gzyzOzuWb2tpktM7N/DGtbktxmjC3hpMFF/PrFVdTrLF6k24R5Bl8LTHP3U4AJwAVmdmaI25MkZWZcP300G3bV8LiuixfpNqEFvCdUBy+zg4eHtT1JbtNPHMD4Ib24+cWVOosX6SahtsGbWdzMFgHbgefc/a1W1rnKzCrMrKKysjLMciRCzW3xm3Yf4OGKTVGXI5IRQg14d2909wnAEGCSmZ3Uyjqz3L3c3cuLi4vDLEciNvWEAUwc1ptfvbCCA3Uao0YkbN1yFY277wFeAi7oju1JcjIzbrzwRLbtq+XON9ZFXY5I2gvzKppiM+sdTOcDM4B3w9qepIZJZX2ZekIxt7y8ir01uuuTSJjCPIMfCLxkZouBeSTa4P8Y4vYkRfzggjFU1Tbwu1dWRV2KSFrLCuuN3X0xcGpY7y+p68SBRXxmwmDufH0dl08uZWCv/KhLEklL6skqkbhhxmia3PnV8yujLkUkbSngJRJD+/bgsjOG81DFRlZtr27/B0SkwxTwEplrp42kR04W//7n5VGXIpKWFPASmX6FucycOpLnl2/n1ZXq5CbS1RTwEqkrppQyrG8P/ul/3tFwwiJdTAEvkcrNivN3F53Iyu3V3PfWhqjLEUkrCniJ3PnjSph8fD/++/kV7Kmpi7ockbShgJfImRk//vRY9h2o5yZdNinSZRTwkhTGHFfEpZOGcc+b61mxrSrqckTSggJeksYNM0ZTkBPnJ08uw123DhDpLAW8JI1+hbn88MIxzFmzk8d05yeRTlPAS1K59PRhTBzWm5/96R127dcXriKdoYCXpBKLGf/6uZOpOtjAvz2tHq4inaGAl6Qz5rgivnXuCB6ev4k5q3dGXY5IylLAS1L63rRRDOvbgxsfW8z+2oaoyxFJSQp4SUr5OXF+ccl4Nuyq4d80GJnIMVHAS9I6Y0Q/vjmljHvf3MArKzQYmUhHKeAlqf3NeScwakAhP3jkbd3DVaSDFPCS1PKy4/zyixPYWV3H3z2xRB2gRDpAAS9J7+QhvbjhvNH8afEW7p6zPupyRFKGAl5SwtXnHs/0MQP42Z/eYcGG3VGXI5ISFPCSEmIx45dfnEBJUR7X3LdAvVxFjoICXlJGrx7Z3HLZaeyoruPa+xdQrztAiR
yRAl5SyslDevEvnz2J11ft5P8/vlRfuoocQVbUBYh01BfKh7JhVw2/fnEVQ/vmc820UVGXJJKUFPCSkm6YMZqNu2r4z2dX0DMvm69PLo26JJGko4CXlGRm/OILp7C/rpGfPLWM/Ow4Xzx9aNRliSSV0NrgzWyomb1kZu+Y2TIzuy6sbUlmyo7H+M1XTuXc0cX84NHF3PXGuqhLEkkqYX7J2gD8jbuPBc4EZprZ2BC3JxkoNyvOrK+dxnljS/jJU8u46fkV+uJVJBBawLv7FndfEExXAcuBwWFtTzJXXnac3102kUtOG8JNz6/kmvsXUlOnIYZFuuUySTMrBU4F3mpl2VVmVmFmFZWVGjFQjk1WPMYvLhnPjy4cw9NLtnDJLXPYtLsm6rJEIhV6wJtZIfAocL277zt8ubvPcvdydy8vLi4OuxxJY2bGtz92PHdcfjobd9fwyZtf46m334+6LJHIhBrwZpZNItzvc/fHwtyWSLOPnzCAp66ZQln/Ar53/0Jm/mEBuzW0gWSgMK+iMWA2sNzdfxnWdkRaU9a/gEeuPovvn38Czy7byoz/foUH522gqUlfwErmCPMM/mzga8A0M1sUPC4KcXsiH5IVjzFz6kienDmF4f0K+OGjS7j4t68zb92uqEsT6RaWTJeUlZeXe0VFRdRlSBpyd556+33+7el32brvIOeM6s9100dRXto36tJEOsXM5rt7eavLFPCSSWrqGrhnznpm/WUNO/fXceaIvlxxdhnTTywhHrOoyxPpMAW8yGFq6hr4w1sbmP3aWrbsPcjg3vlcduYwPnvqYAb2yo+6PJGjpoAXaUNDYxPPvbONu+esZ86anZjBGWV9+cyEwVx40kB69ciOukSRI1LAixyFtTv28+SizTy56H3W7thPdtyYVNaXqScMYNqYAYwoLoy6RJGPUMCLdIC7s2TzXv60eAsvvrudldurARjerwdnjejH6aV9mVTWlyF98klcDSwSHQW8SCds3FXDy+9t55UVlcxdu4t9BxPj3Azslcf4Ib0YN6gXJw0u4qRBvRhQlBdxtZJpFPAiXaSpyVmxvYq5a3cxb91ulm3ey5od+w8t71+Yy+iSQkYUFzCif+L5+OJCBvXO11U6EgoFvEiIqg7Ws3xLFUs372XZ+/tYVVnNmspqqg5+MKJlblaMwX3yGdw7n0G98hnUO59BvfOC53xKinLpkaP770jHHSng9Rsl0kk987KZVJZol2/m7uyormNNZTVrduxnTWU1m/ccYPOeg7y7dTuVVbUfeZ/87Dj9CnPoX5hL/8Ic+hXk0q8wh36FufQtyKYoL5ui/ObnLIrysumRE9f3ANImBbxICMyM4p65FPfM5YwR/T6yvLahka17D/L+noNs3nOAyqpadlbXsnN/HTuqa3l/z0EWb9rLrv11NBxh/Jx4zCjKy6JnEPo9cxOhn5cTp0d2nB45cfJzshLP2XHycxLzmufnZ8fJzYqR0/yIx8jNjpEbjx+ap6al1KWAF4lAblac4f0KGN6v4IjrNTU5ew/Us+dAPfsO1LPvYD1VBxsOTe870BA817PvYANVB+vZVlVPTV0jB+oaDz3XNTYdc61ZMfvIB0BOPEZOVvAhEDfiMSM7nvgwyIrFyD5sXvPrD5Z9dJ2sj6wfIytmxGJG3IyY8cF0DGKW+PmY2aHpeCzx4RoPXpslPgQTPxOs18rPJ55bvH+wzgfvk5ofcgp4kSQWixl9CnLoU5DTqfdpaGyipr6Rg0Ho19Q1cqC+gZq6RuoamhKPxiZqGxKPuhaP2obGQ8vrWiyvDZY1NjkNjc7+hgYagumGpqbg2WlobEo8tzKdSoN7mhF8mCQ+RIwPXsfMwD782oL1Dr0meB374HXzev0Kcnno6rO6vGYFvEgGyIrHKIrHKMpLrp65TUHYNzY59U1NNDYGz4c+KBLL3J1GT0w3NUFT8LopWN7kwbwmp8k9mOajP+uJnz/0sx78bNNHf7b5/RqbHCfxvYoH22ny4HWwD83bBw69R2KdxHptvg7W75kbThQr4EUkMrGYkRO08ecTj7ia9NMt92QVEZHup4AXEUlTCngRkTSlgBcRSVMKeBGRNKWAFxFJUwp4EZE0pYAXEUlTSTVcsJlVAuuP8cf7Azu6sJwoaV+ST7rsB2hfktWx7stwdy9ubUFSBXxnmFlFW2MipxrtS/JJl/0A7UuyCmNf1EQjIpKmFPAiImkqnQJ+VtQFdCHtS/JJl/0A7Uuy6vJ9SZs2eBER+bB0OoMXEZEWFPAiImkq5QPezC4ws/fMbJWZ3Rh1PR1lZuvMbImZLTKzimBeXzN7zsxWBs99oq6zNWZ2u5ltN7OlLea1Wrsl3Bwcp8VmNjG6yj+qjX35qZltDo7NIjO7qMWyHwX78p6ZnR9N1a0zs6Fm9pKZvWNmy8zsumB+yh2bI+xLyh0bM8szs7lm9nawL/8YzC8zs7eCmh80s5xgfm7welWwvLTDG03chio1H0AcWA2MAHKAt4GxUdfVwX1YB/Q/bN7PgRuD6RuB/4i6zjZqPxeYCCxtr3bgIuDPgAFnAm9FXf9R7MtPgb9tZd2xwe9aLlAW/A7Go96HFvUNBCYG0z2BFUHNKXdsjrAvKXdsgn/fwmA6G3gr+Pd+CPhyMP9W4DvB9HeBW4PpLwMPdnSbqX4GPwlY5e5r3L0OeAC4OOKausLFwF3B9F3AZ6IrpW3u/hdg12Gz26r9YuBuT3gT6G1mA7ul0KPQxr605WLgAXevdfe1wCoSv4tJwd23uPuCYLoKWA4MJgWPzRH2pS1Je2yCf9/q4GV28HBgGvBIMP/w49J8vB4BppuZdWSbqR7wg4GNLV5v4sgHPxk58KyZzTezq4J5Je6+JZjeCpREU9oxaav2VD1W1wTNFre3aCpLmX0J/qw/lcTZYkofm8P2BVLw2JhZ3MwWAduB50j8hbHH3RuCVVrWe2hfguV7gX4d2V6qB3w6mOLuE4ELgZlmdm7LhZ74+ywlr2VN5doDtwDHAxOALcB/RVpNB5lZIfAocL2772u5LNWOTSv7kpLHxt0b3X0CMITEXxZjwtxeqgf8ZmBoi9dDgnkpw903B8/bgcdJHPRtzX8iB8/bo6uww9qqPeWOlbtvC/5DNgG/54M/9ZN+X8wsm0Qg3ufujwWzU/LYtLYvqXxsANx9D/AScBaJJrGsYFHLeg/tS7C8F7CzI9tJ9YCfB4wKvoXOIfFFxFMR13TUzKzAzHo2TwPnAUtJ7MPXg9W+DjwZTYXHpK3anwL+Orhi40xgb4vmgqR0WDv0Z0kcG0jsy5eDqxzKgFHA3O6ury1BO+1sYLm7/7LFopQ7Nm3tSyoeGzMrNrPewXQ+MIPEdwovAZcEqx1+XJqP1yXAi8FfXkcv6m+Wu+Cb6YtIfLO+Gvj7qOvpYO0jSHzj/zawrLl+Eu1sLwArgeeBvlHX2kb995P487ieRNvhlW3VTuIKgt8Gx2kJUB51/UexL/cEtS4O/rMNbLH+3wf78h5wYdT1H7YvU0g0vywGFgWPi1Lx2BxhX1Lu2ADjgYVBzUuBHwfzR5D4E
FoFPAzkBvPzgterguUjOrpNDVUgIpKmUr2JRkRE2qCAFxFJUwp4EZE0pYAXEUlTCngRkTSlgJe0Z2aNLUYdXGRdOOqomZW2HIFSJJlktb+KSMo74Inu4SIZRWfwkrEsMRb/zy0xHv9cMxsZzC81sxeDgaxeMLNhwfwSM3s8GM/7bTObHLxV3Mx+H4zx/WzQSxEz+14wjvliM3sgot2UDKaAl0yQf1gTzZdaLNvr7icDvwFuCub9GrjL3ccD9wE3B/NvBl5x91NIjB2/LJg/Cvitu48D9gCfD+bfCJwavM/V4eyaSNvUk1XSnplVu3thK/PXAdPcfU0woNVWd+9nZjtIdH2vD+Zvcff+ZlYJDHH32hbvUQo85+6jgtc/BLLd/Wdm9gxQDTwBPOEfjAUu0i10Bi+ZztuY7ojaFtONfPDd1idJjPEyEZjXYsRAkW6hgJdM96UWz3OC6TdIjEwKcBnwajD9AvAdOHTjhl5tvamZxYCh7v4S8EMSQ71+5K8IkTDpjEIyQX5wF51mz7h786WSfcxsMYmz8EuDedcCd5jZ94FK4BvB/OuAWWZ2JYkz9e+QGIGyNXHg3uBDwICbPTEGuEi3URu8ZKygDb7c3XdEXYtIGNREIyKSpnQGLyKSpnQGLyKSphTwIiJpSgEvIpKmFPAiImlKAS8ikqb+DwqRNr0rGLthAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4x8hQP_bg4_g" - }, - "source": [ - "" - ], - "execution_count": null, - "outputs": [] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" } - ] + ], + "source": [ + "start2 = time.time()\n", + "print(\"=================================================\")\n", + "print(\"Collecting training data\")\n", + "starte = time.time()\n", + "data = get_training_data(text,window_size)\n", + "ende = str(datetime.timedelta(seconds = time.time()-starte))\n", + "print(\"It took {} to collect the data\".format(ende))\n", + "print(\"The training data has {} target-context pairs\".format(len(data)))\n", + "print(\"===================================================\")\n", + "\n", + "train(data,batch_size)\n", + "\n", + "end2 = str(datetime.timedelta(seconds = time.time()-start2))\n", + "print(\"Training finished.\")\n", + "print(\"It took {} to finish training the model\".format(end2))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BgQkaYstyj0Q" + }, + "source": [ + "# 1.10 Train on the full dataset (0.5 points)\n", + "\n", + "Now, go back to 1.1 and remove the restriction on the number of sentences in your corpus. Then, reexecute code blocks 1.2, 1.3 and 1.6 (or those relevant if you created additional ones). \n", + "\n", + "* Then, retrain your model on the complete dataset.\n", + "\n", + "* Now, the input weights of the model contain the desired word embeddings! Save them together with the corresponding vocabulary items (Pytorch provides a nice [functionality](https://pytorch.org/tutorials/beginner/saving_loading_models.html) for this)." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "id": "4x8hQP_bg4_g" + }, + "outputs": [], + "source": [ + "data_full = pd.read_csv('https://raw.githubusercontent.com/SouravDutta91/NNTI-WS2021-NLP-Project/main/data/hindi_hatespeech.tsv',sep='\\t')\n", + "full_data = data_full['text']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Started Preprocessing\")\n", + "cleantext_full = corpus_preprocess(full_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_full = tokens(cleantext_full)\n", + "text_full = text_full['text'].apply(lambda x: x.split())\n", + "\n", + "v = list(set(text_full.sum())) \n", + "all_word = list(text_full.sum()) \n", + "\n", + "fullword_index = {word: i for i,word in enumerate(v)}\n", + "fullindex_word = {i: word for i,word in enumerate(v)}\n", + "\n", + "data = get_training_data(text,window_size)\n", + "train(data,batch_size)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "NNTI_final_project_task_1.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 } \ No newline at end of file