Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neural network support webgpu #138

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions examples/NeuralNetwork-color-classifier/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,14 @@ let data = [
{ r: 0, g: 0, b: 253, color: "blue-ish" },
];

let classifer;
let classifier;
let r = 255;
let g = 0;
let b = 0;
let rSlider, gSlider, bSlider;
let label = "training";

function setup() {
createCanvas(640, 240);

// For this example to work across all browsers
// "webgl" or "cpu" needs to be set as the backend
ml5.setBackend("webgl");

rSlider = createSlider(0, 255, 255).position(10, 20);
gSlider = createSlider(0, 255, 0).position(10, 40);
bSlider = createSlider(0, 255, 0).position(10, 60);

function preload() {
// Step 2: set your neural network options
let options = {
task: "classification",
Expand All @@ -45,6 +35,14 @@ function setup() {

// Step 3: initialize your neural network
classifier = ml5.neuralNetwork(options);
}

function setup() {
createCanvas(640, 240);

rSlider = createSlider(0, 255, 255).position(10, 20);
gSlider = createSlider(0, 255, 0).position(10, 40);
bSlider = createSlider(0, 255, 0).position(10, 60);

// Step 4: add data to the neural network
for (let i = 0; i < data.length; i++) {
Expand Down
15 changes: 5 additions & 10 deletions examples/NeuralNetwork-load/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ let isModelLoaded = false;
function preload() {
// Load the handPose model
handPose = ml5.handPose();
// Set up the neural network
let classifierOptions = {
task: "classification",
};
classifier = ml5.neuralNetwork(classifierOptions);
}

function setup() {
Expand All @@ -26,16 +31,6 @@ function setup() {
video.size(width, height);
video.hide();

// For this example to work across all browsers
// "webgl" or "cpu" needs to be set as the backend
ml5.setBackend("webgl");

// Set up the neural network
let classifierOptions = {
task: "classification",
};
classifier = ml5.neuralNetwork(classifierOptions);

const modelDetails = {
model: "model/model.json",
metadata: "model/model_meta.json",
Expand Down
10 changes: 4 additions & 6 deletions examples/NeuralNetwork-mouse-gesture/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ let label = "training";

let start, end;

function setup() {
createCanvas(640, 240);
// For this example to work across all browsers
// "webgl" or "cpu" needs to be set as the backend
ml5.setBackend("webgl");

function preload() {
// Step 2: set your neural network options
let options = {
task: "classification",
Expand All @@ -37,7 +32,10 @@ function setup() {

// Step 3: initialize your neural network
classifier = ml5.neuralNetwork(options);
}

function setup() {
createCanvas(640, 240);
// Step 4: add data to the neural network
for (let i = 0; i < data.length; i++) {
let item = data[i];
Expand Down
17 changes: 6 additions & 11 deletions examples/NeuralNetwork-train-and-save/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ let trainButton;
function preload() {
// Load the handPose model, we will use the keypoints form handPose to train the neural network
handPose = ml5.handPose();
// Set up the neural network
let classifierOptions = {
task: "classification",
debug: true,
};
classifier = ml5.neuralNetwork(classifierOptions);
}

function setup() {
Expand Down Expand Up @@ -59,17 +65,6 @@ function setup() {
trainButton = createButton("Train and Save Model");
trainButton.mousePressed(train);

// For this example to work across all browsers
// "webgl" or "cpu" needs to be set as the backend
ml5.setBackend("webgl");

// Set up the neural network
let classifierOptions = {
task: "classification",
debug: true,
};
classifier = ml5.neuralNetwork(classifierOptions);

// Start the handPose detection
handPose.detectStart(video, gotHands);
}
Expand Down
16 changes: 11 additions & 5 deletions examples/NeuroEvolution-flappy-bird/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

let birds = [];
let pipes = [];
let nextBirds = [];

function setup() {
createCanvas(640, 240);
Expand Down Expand Up @@ -64,15 +65,20 @@ function allBirdsDead() {
}

function reproduction() {
let nextBirds = [];
for (let i = 0; i < birds.length; i++) {
let parentA = weightedSelection();
let parentB = weightedSelection();
let child = parentA.crossover(parentB);
child.mutate(0.01);
nextBirds[i] = new Bird(child);
parentA.crossover(parentB, gotCrossOver);
}
}

function gotCrossOver(child) {
child.mutate(0.01);
nextBirds.push(new Bird(child));
if (nextBirds.length == birds.length) {
birds = nextBirds;
nextBirds = [];
}
birds = nextBirds;
}

// Normalize all fitness values
Expand Down
6 changes: 3 additions & 3 deletions examples/NeuroEvolution-sensors/creature.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ class Creature {
}
}

reproduce() {
let brain = this.brain.copy();
async reproduce(callback) {
let brain = await this.brain.copy(this.gotCopy);
brain.mutate(0.1);
return new Creature(this.position.x, this.position.y, brain);
callback(new Creature(this.position.x, this.position.y, brain));
}

eat() {
Expand Down
7 changes: 5 additions & 2 deletions examples/NeuroEvolution-sensors/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ function draw() {
if (bloops[i].health < 0) {
bloops.splice(i, 1);
} else if (random(1) < 0.001) {
let child = bloops[i].reproduce();
bloops.push(child);
bloops[i].reproduce(gotChild);
}
}
}
Expand All @@ -54,3 +53,7 @@ function draw() {
bloop.show();
}
}

function gotChild(child) {
bloops.push(child);
}
17 changes: 12 additions & 5 deletions examples/NeuroEvolution-steering/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

let creatures = [];
let nextCreatures = [];
let timeSlider;
let lifeSpan = 250; // How long should each generation live
let lifeCounter = 0; // Timer for cycle of generation
Expand Down Expand Up @@ -67,15 +68,21 @@ function normalizeFitness() {
}

function reproduction() {
let nextCreatures = [];
for (let i = 0; i < creatures.length; i++) {
let parentA = weightedSelection();
let parentB = weightedSelection();
let child = parentA.crossover(parentB);
child.mutate(0.1);
nextCreatures[i] = new Creature(random(width), random(height), child);
parentA.crossover(parentB, gotChild);
}
}

function gotChild(child) {
child.mutate(0.1);
let childCreature = new Creature(random(width), random(height), child);
nextCreatures.push(childCreature);
if (nextCreatures.length === creatures.length) {
creatures = nextCreatures;
nextCreatures = [];
}
creatures = nextCreatures;
}

function weightedSelection() {
Expand Down
84 changes: 51 additions & 33 deletions src/NeuralNetwork/index.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import * as tf from "@tensorflow/tfjs";
import callCallback from "../utils/callcallback";
import handleArguments from "../utils/handleArguments";
import { imgToPixelArray, isInstanceOfSupportedElement, } from "../utils/imageUtilities";
import {
imgToPixelArray,
isInstanceOfSupportedElement,
} from "../utils/imageUtilities";
import NeuralNetwork from "./NeuralNetwork";
import NeuralNetworkData from "./NeuralNetworkData";

Expand All @@ -22,7 +25,6 @@ const DEFAULTS = {
};
class DiyNeuralNetwork {
constructor(options, callback) {

// Is there a better way to handle a different
// default learning rate for image classification tasks?
if (options.task === "imageClassification") {
Expand Down Expand Up @@ -107,6 +109,7 @@ class DiyNeuralNetwork {
* @return {Promise<this>} - will be awaited by this.ready
*/
async init() {
await tf.ready();
// check if the a static model should be built based on the inputs and output properties
if (this.options.neuroEvolution === true) {
this.createLayersNoTraining();
Expand Down Expand Up @@ -148,16 +151,22 @@ class DiyNeuralNetwork {
/**
* copy
*/
copy() {
const nnCopy = new DiyNeuralNetwork(this.options);
return tf.tidy(() => {
const weights = this.neuralNetwork.model.getWeights();
const weightCopies = [];
for (let i = 0; i < weights.length; i += 1) {
weightCopies[i] = weights[i].clone();
}
nnCopy.neuralNetwork.model.setWeights(weightCopies);
return nnCopy;
copy(callback) {
return new Promise((resolve) => {
const nnCopy = new DiyNeuralNetwork(this.options, () => {
tf.tidy(() => {
const weights = this.neuralNetwork.model.getWeights();
const weightCopies = [];
for (let i = 0; i < weights.length; i += 1) {
weightCopies[i] = weights[i].clone();
}
nnCopy.neuralNetwork.model.setWeights(weightCopies);
if (callback) {
callback(nnCopy);
}
resolve(nnCopy);
});
});
});
}

Expand Down Expand Up @@ -227,11 +236,7 @@ class DiyNeuralNetwork {
async loadDataFromUrl() {
const { dataUrl, inputs, outputs } = this.options;

await this.neuralNetworkData.loadDataFromUrl(
dataUrl,
inputs,
outputs
);
await this.neuralNetworkData.loadDataFromUrl(dataUrl, inputs, outputs);

// once the data are loaded, create the metadata
// and prep the data for training
Expand Down Expand Up @@ -512,7 +517,10 @@ class DiyNeuralNetwork {
finishedTrainingCb = optionsOrCallback;
}

return callCallback(this.trainInternal(options, whileTrainingCb), finishedTrainingCb);
return callCallback(
this.trainInternal(options, whileTrainingCb),
finishedTrainingCb
);
}

/**
Expand Down Expand Up @@ -579,9 +587,7 @@ class DiyNeuralNetwork {
// then use those to create your architecture
if (!this.neuralNetwork.isLayered) {
// TODO: don't update this.options.layers - Linda
this.options.layers = this.createNetworkLayers(
this.options.layers
);
this.options.layers = this.createNetworkLayers(this.options.layers);
}

// if the model does not have any layers defined yet
Expand Down Expand Up @@ -1143,7 +1149,10 @@ class DiyNeuralNetwork {
*/
saveData(name, callback) {
const args = handleArguments(name, callback);
return callCallback(this.neuralNetworkData.saveData(args.name), args.callback);
return callCallback(
this.neuralNetworkData.saveData(args.name),
args.callback
);
}

/**
Expand Down Expand Up @@ -1175,13 +1184,16 @@ class DiyNeuralNetwork {
*/
async save(name, callback) {
const args = handleArguments(name, callback);
const modelName = args.string || 'model';
const modelName = args.string || "model";

// save the model
return callCallback(Promise.all([
this.neuralNetwork.save(modelName),
this.neuralNetworkData.saveMeta(modelName)
]), args.callback);
return callCallback(
Promise.all([
this.neuralNetwork.save(modelName),
this.neuralNetworkData.saveMeta(modelName),
]),
args.callback
);
}

/**
Expand All @@ -1193,10 +1205,13 @@ class DiyNeuralNetwork {
* @return {Promise<void[]>}
*/
async load(filesOrPath, callback) {
return callCallback(Promise.all([
this.neuralNetwork.load(filesOrPath),
this.neuralNetworkData.loadMeta(filesOrPath)
]), callback);
return callCallback(
Promise.all([
this.neuralNetwork.load(filesOrPath),
this.neuralNetworkData.loadMeta(filesOrPath),
]),
callback
);
}

/**
Expand Down Expand Up @@ -1227,9 +1242,12 @@ class DiyNeuralNetwork {
* @param {*} other
*/

crossover(other) {
const nnCopy = this.copy();
async crossover(other, callback) {
const nnCopy = await this.copy();
nnCopy.neuralNetwork.crossover(other.neuralNetwork);
if (callback) {
callback(nnCopy);
}
return nnCopy;
}
}
Expand Down
Loading