Skip to content

Commit

Permalink
Merge pull request #173 from meiamsome/testing/word2vec
Browse files Browse the repository at this point in the history
Word2Vec tests
  • Loading branch information
cvalenzuela authored Jun 27, 2018
2 parents 8431d99 + d921d83 commit 82843f8
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 57 deletions.
37 changes: 31 additions & 6 deletions karma.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,42 @@ module.exports = (config) => {
config.set({
frameworks: ['jasmine'],
files: [
'src/index.js',
'src/**/*_test.js',
],
preprocessors: {
'src/**/*_test.js': ['webpack'],
'src/index.js': ['webpack'],
},
webpack: {
// karma watches the test entry points
// (you don't need to specify the entry option)
// webpack watches dependencies

// webpack configuration
// TODO: This is duplication of the webpack.common.babel.js file, but they
// use different import syntaxes so it's not easy to just require it here.
// Maybe this could be put into a JSON file, but the include in the module
// rules is dynamic.
entry: ['babel-polyfill', './src/index.js'],
output: {
libraryTarget: 'umd',
filename: 'ml5.js',
library: 'ml5',
},
module: {
rules: [
{
enforce: 'pre',
test: /\.js$/,
exclude: /node_modules/,
loader: 'eslint-loader',
},
{
test: /\.js$/,
loader: 'babel-loader',
include: require('path').resolve(__dirname, 'src'),
},
],
},
// Don't minify the webpack build for better stack traces
optimization: {
minimize: false,
},
},
webpackMiddleware: {
noInfo: true,
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"start": "webpack-dev-server --open --config webpack.dev.babel.js",
"build": "webpack --config webpack.prod.babel.js",
"test": "./node_modules/karma/bin/karma start karma.conf.js",
"test:single": "./node_modules/karma/bin/karma start karma.conf.js --single-run",
"test-travis": "./scripts/test-travis.sh"
},
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion src/ImageClassifier/index_test.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint new-cap: 0 */

import * as ImageClassifier from './index';
const { ImageClassifier } = ml5;

const DEFAULTS = {
learningRate: 0.0001,
Expand Down
74 changes: 43 additions & 31 deletions src/Word2vec/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,30 @@ class Word2Vec {
});
}

dispose() {
Object.values(this.model).forEach(x => x.dispose());
}

add(inputs, max = 1) {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
return tf.tidy(() => {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
});
}

subtract(inputs, max = 1) {
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
return tf.tidy(() => {
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
});
}

average(inputs, max = 1) {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
const avg = tf.div(sum, tf.tensor(inputs.length));
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
return tf.tidy(() => {
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
const avg = tf.div(sum, tf.tensor(inputs.length));
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
});
}

nearest(input, max = 10) {
Expand All @@ -64,34 +74,36 @@ class Word2Vec {
}

static addOrSubtract(model, values, operation) {
const vectors = [];
const notFound = [];
if (values.length < 2) {
throw new Error('Invalid input, must be passed more than 1 value');
}
values.forEach((value) => {
const vector = model[value];
if (!vector) {
notFound.push(value);
} else {
vectors.push(vector);
return tf.tidy(() => {
const vectors = [];
const notFound = [];
if (values.length < 2) {
throw new Error('Invalid input, must be passed more than 1 value');
}
});
values.forEach((value) => {
const vector = model[value];
if (!vector) {
notFound.push(value);
} else {
vectors.push(vector);
}
});

if (notFound.length > 0) {
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
}
let result = vectors[0];
if (operation === 'ADD') {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.add(result, vectors[i]);
if (notFound.length > 0) {
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
}
} else {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.sub(result, vectors[i]);
let result = vectors[0];
if (operation === 'ADD') {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.add(result, vectors[i]);
}
} else {
for (let i = 1; i < vectors.length; i += 1) {
result = tf.sub(result, vectors[i]);
}
}
}
return result;
return result;
});
}

static nearest(model, input, start, max) {
Expand Down
115 changes: 96 additions & 19 deletions src/Word2vec/index_test.js
Original file line number Diff line number Diff line change
@@ -1,19 +1,96 @@
// import Word2Vec from './index';

// const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/07_Word2Vec/data/wordvecs1000.json';

// describe('initialize word2vec', () => {
// let word2vec;
// beforeAll((done) => {
// // word2vec = new Word2Vec(URL);
// done();
// });

// // it('creates a new instance', (done) => {
// // expect(word2vec).toEqual(jasmine.objectContaining({
// // ready: true,
// // modelSize: 1,
// // }));
// // done();
// // });
// });
const { tf, word2vec } = ml5;

const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/Word2Vec/data/wordvecs1000.json';

describe('word2vec', () => {
let word2vecInstance;
let numTensorsBeforeAll;
let numTensorsBeforeEach;
beforeAll((done) => {
numTensorsBeforeAll = tf.memory().numTensors;
word2vecInstance = word2vec(URL, done);
});

afterAll(() => {
word2vecInstance.dispose();
let numTensorsAfterAll = tf.memory().numTensors;
if(numTensorsBeforeAll !== numTensorsAfterAll) {
throw new Error(`Leaking Tensors (${numTensorsAfterAll} vs ${numTensorsBeforeAll})`);
}
});

beforeEach(() => {
numTensorsBeforeEach = tf.memory().numTensors;
});

afterEach(() => {
let numTensorsAfterEach = tf.memory().numTensors;
if(numTensorsBeforeEach !== numTensorsAfterEach) {
throw new Error(`Leaking Tensors (${numTensorsAfterEach} vs ${numTensorsBeforeEach})`);
}
});

it('creates a new instance', () => {
expect(word2vecInstance).toEqual(jasmine.objectContaining({
ready: true,
modelSize: 1,
}));
});

describe('getRandomWord', () => {
it('returns a word', () => {
let word = word2vecInstance.getRandomWord();
expect(typeof word).toEqual('string');
});
});

describe('nearest', () => {
it('returns a sorted array of nearest words', () => {
for(let i = 0; i < 100; i++) {
let word = word2vecInstance.getRandomWord();
let nearest = word2vecInstance.nearest(word);
let currentDistance = 0;
for(let { word, distance: nextDistance } of nearest) {
expect(typeof word).toEqual('string');
expect(nextDistance).toBeGreaterThan(currentDistance);
currentDistance = nextDistance;
}
}
});

it('returns a list of the right length', () => {
for(let i = 0; i < 100; i++) {
let word = word2vecInstance.getRandomWord();
let nearest = word2vecInstance.nearest(word, i);
expect(nearest.length).toEqual(i);
}
});
});

describe('add', () => {
it('returns a value', () => {
let word1 = word2vecInstance.getRandomWord();
let word2 = word2vecInstance.getRandomWord();
let sum = word2vecInstance.subtract([word1, word2]);
expect(sum[0].distance).toBeGreaterThan(0);
})
});

describe('subtract', () => {
it('returns a value', () => {
let word1 = word2vecInstance.getRandomWord();
let word2 = word2vecInstance.getRandomWord();
let sum = word2vecInstance.subtract([word1, word2]);
expect(sum[0].distance).toBeGreaterThan(0);
})
});

describe('average', () => {
it('returns a value', () => {
let word1 = word2vecInstance.getRandomWord();
let word2 = word2vecInstance.getRandomWord();
let average = word2vecInstance.average([word1, word2]);
expect(average[0].distance).toBeGreaterThan(0);
});
});
});

0 comments on commit 82843f8

Please sign in to comment.