diff --git a/karma.conf.js b/karma.conf.js index c5ba1e98d..e5ed19a8c 100644 --- a/karma.conf.js +++ b/karma.conf.js @@ -2,17 +2,42 @@ module.exports = (config) => { config.set({ frameworks: ['jasmine'], files: [ + 'src/index.js', 'src/**/*_test.js', ], preprocessors: { - 'src/**/*_test.js': ['webpack'], + 'src/index.js': ['webpack'], }, webpack: { - // karma watches the test entry points - // (you don't need to specify the entry option) - // webpack watches dependencies - - // webpack configuration + // TODO: This is duplication of the webpack.common.babel.js file, but they + // use different import syntaxes so it's not easy to just require it here. + // Maybe this could be put into a JSON file, but the include in the module + // rules is dynamic. + entry: ['babel-polyfill', './src/index.js'], + output: { + libraryTarget: 'umd', + filename: 'ml5.js', + library: 'ml5', + }, + module: { + rules: [ + { + enforce: 'pre', + test: /\.js$/, + exclude: /node_modules/, + loader: 'eslint-loader', + }, + { + test: /\.js$/, + loader: 'babel-loader', + include: require('path').resolve(__dirname, 'src'), + }, + ], + }, + // Don't minify the webpack build for better stack traces + optimization: { + minimize: false, + }, }, webpackMiddleware: { noInfo: true, diff --git a/package.json b/package.json index ef7164818..00fdd261b 100644 --- a/package.json +++ b/package.json @@ -13,6 +13,7 @@ "start": "webpack-dev-server --open --config webpack.dev.babel.js", "build": "webpack --config webpack.prod.babel.js", "test": "./node_modules/karma/bin/karma start karma.conf.js", + "test:single": "./node_modules/karma/bin/karma start karma.conf.js --single-run", "test-travis": "./scripts/test-travis.sh" }, "repository": { diff --git a/src/ImageClassifier/index_test.js b/src/ImageClassifier/index_test.js index f6242fd23..ec7e7ec65 100644 --- a/src/ImageClassifier/index_test.js +++ b/src/ImageClassifier/index_test.js @@ -1,6 +1,6 @@ /* eslint new-cap: 0 */ -import * as ImageClassifier from 
'./index'; +const { ImageClassifier } = ml5; const DEFAULTS = { learningRate: 0.0001, diff --git a/src/Word2vec/index.js b/src/Word2vec/index.js index f88b373ed..aeb5717f0 100644 --- a/src/Word2vec/index.js +++ b/src/Word2vec/index.js @@ -34,20 +34,30 @@ class Word2Vec { }); } + dispose() { + Object.values(this.model).forEach(x => x.dispose()); + } + add(inputs, max = 1) { - const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD'); - return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max); + return tf.tidy(() => { + const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD'); + return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max); + }); } subtract(inputs, max = 1) { - const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT'); - return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max); + return tf.tidy(() => { + const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT'); + return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max); + }); } average(inputs, max = 1) { - const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD'); - const avg = tf.div(sum, tf.tensor(inputs.length)); - return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max); + return tf.tidy(() => { + const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD'); + const avg = tf.div(sum, tf.tensor(inputs.length)); + return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max); + }); } nearest(input, max = 10) { @@ -64,34 +74,36 @@ class Word2Vec { } static addOrSubtract(model, values, operation) { - const vectors = []; - const notFound = []; - if (values.length < 2) { - throw new Error('Invalid input, must be passed more than 1 value'); - } - values.forEach((value) => { - const vector = model[value]; - if (!vector) { - notFound.push(value); - } else { - vectors.push(vector); + return tf.tidy(() => { + const vectors = []; + 
const notFound = []; + if (values.length < 2) { + throw new Error('Invalid input, must be passed more than 1 value'); } - }); + values.forEach((value) => { + const vector = model[value]; + if (!vector) { + notFound.push(value); + } else { + vectors.push(vector); + } + }); - if (notFound.length > 0) { - throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`); - } - let result = vectors[0]; - if (operation === 'ADD') { - for (let i = 1; i < vectors.length; i += 1) { - result = tf.add(result, vectors[i]); + if (notFound.length > 0) { + throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`); } - } else { - for (let i = 1; i < vectors.length; i += 1) { - result = tf.sub(result, vectors[i]); + let result = vectors[0]; + if (operation === 'ADD') { + for (let i = 1; i < vectors.length; i += 1) { + result = tf.add(result, vectors[i]); + } + } else { + for (let i = 1; i < vectors.length; i += 1) { + result = tf.sub(result, vectors[i]); + } } - } - return result; + return result; + }); } static nearest(model, input, start, max) { diff --git a/src/Word2vec/index_test.js b/src/Word2vec/index_test.js index b295ebc16..211b31a6e 100644 --- a/src/Word2vec/index_test.js +++ b/src/Word2vec/index_test.js @@ -1,19 +1,96 @@ -// import Word2Vec from './index'; - -// const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/07_Word2Vec/data/wordvecs1000.json'; - -// describe('initialize word2vec', () => { -// let word2vec; -// beforeAll((done) => { -// // word2vec = new Word2Vec(URL); -// done(); -// }); - -// // it('creates a new instance', (done) => { -// // expect(word2vec).toEqual(jasmine.objectContaining({ -// // ready: true, -// // modelSize: 1, -// // })); -// // done(); -// // }); -// }); +const { tf, word2vec } = ml5; + +const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/Word2Vec/data/wordvecs1000.json'; + +describe('word2vec', () => { + let word2vecInstance; + let numTensorsBeforeAll; + 
let numTensorsBeforeEach; + beforeAll((done) => { + numTensorsBeforeAll = tf.memory().numTensors; + word2vecInstance = word2vec(URL, done); + }); + + afterAll(() => { + word2vecInstance.dispose(); + let numTensorsAfterAll = tf.memory().numTensors; + if(numTensorsBeforeAll !== numTensorsAfterAll) { + throw new Error(`Leaking Tensors (${numTensorsAfterAll} vs ${numTensorsBeforeAll})`); + } + }); + + beforeEach(() => { + numTensorsBeforeEach = tf.memory().numTensors; + }); + + afterEach(() => { + let numTensorsAfterEach = tf.memory().numTensors; + if(numTensorsBeforeEach !== numTensorsAfterEach) { + throw new Error(`Leaking Tensors (${numTensorsAfterEach} vs ${numTensorsBeforeEach})`); + } + }); + + it('creates a new instance', () => { + expect(word2vecInstance).toEqual(jasmine.objectContaining({ + ready: true, + modelSize: 1, + })); + }); + + describe('getRandomWord', () => { + it('returns a word', () => { + let word = word2vecInstance.getRandomWord(); + expect(typeof word).toEqual('string'); + }); + }); + + describe('nearest', () => { + it('returns a sorted array of nearest words', () => { + for(let i = 0; i < 100; i++) { + let word = word2vecInstance.getRandomWord(); + let nearest = word2vecInstance.nearest(word); + let currentDistance = 0; + for(let { word, distance: nextDistance } of nearest) { + expect(typeof word).toEqual('string'); + expect(nextDistance).toBeGreaterThan(currentDistance); + currentDistance = nextDistance; + } + } + }); + + it('returns a list of the right length', () => { + for(let i = 0; i < 100; i++) { + let word = word2vecInstance.getRandomWord(); + let nearest = word2vecInstance.nearest(word, i); + expect(nearest.length).toEqual(i); + } + }); + }); + + describe('add', () => { + it('returns a value', () => { + let word1 = word2vecInstance.getRandomWord(); + let word2 = word2vecInstance.getRandomWord(); + let sum = word2vecInstance.add([word1, word2]); + expect(sum[0].distance).toBeGreaterThan(0); + }) + }); + + describe('subtract', () => 
{ + it('returns a value', () => { + let word1 = word2vecInstance.getRandomWord(); + let word2 = word2vecInstance.getRandomWord(); + let sum = word2vecInstance.subtract([word1, word2]); + expect(sum[0].distance).toBeGreaterThan(0); + }) + }); + + describe('average', () => { + it('returns a value', () => { + let word1 = word2vecInstance.getRandomWord(); + let word2 = word2vecInstance.getRandomWord(); + let average = word2vecInstance.average([word1, word2]); + expect(average[0].distance).toBeGreaterThan(0); + }); + }); +});