Skip to content

Commit

Permalink
examples and test (#259)
Browse files Browse the repository at this point in the history
* examples and test

* fixes

* fixes

* cleanup
  • Loading branch information
sudiptoguha authored Jun 23, 2021
1 parent 1b010f5 commit 1942b88
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,20 @@

package com.amazon.randomcutforest;

import static com.amazon.randomcutforest.testutils.ExampleDataSets.generateFan;
import static java.lang.Math.PI;
import static java.lang.Math.cos;
import static java.lang.Math.sin;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.List;

import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;

import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.Neighbor;

@Tag("functional")
public class DynamicPointSetFunctionalTest {
Expand All @@ -40,108 +49,126 @@ public class DynamicPointSetFunctionalTest {
private static double transitionToBaseProbability;
private static int dataSize;

public double[] rotateClockWise(double[] point, double theta) {
assertTrue(point.length == 2);
static double[] rotateClockWise(double[] point, double theta) {
double[] result = new double[2];
result[0] = cos(theta) * point[0] + sin(theta) * point[1];
result[1] = -sin(theta) * point[0] + cos(theta) * point[1];
return result;
}
/*
* @Test public void movingDensity() { int newDimensions = 2; randomSeed = 123;
*
* RandomCutForest newForest = RandomCutForest.builder() .numberOfTrees(100)
* .sampleSize(256) .dimensions(newDimensions) .randomSeed(randomSeed)
* .windowSize(800) .centerOfMassEnabled(true)
* .storeSequenceIndexesEnabled(true) .build();
*
* double[][] data = generateFan(1000, 3);
*
* double[] queryPoint = new double[]{0.7, 0}; for (int degree = 0; degree <
* 360; degree += 2) { for (int j = 0; j < data.length; j++) {
* newForest.update(rotateClockWise(data[j], 2 * PI * degree / 360)); }
* DensityOutput density = newForest.getSimpleDensity(queryPoint); double value
* = density.getDensity(0.001, 2); if ((degree <= 60) || ((degree >= 120) &&
* (degree <= 180)) || ((degree >= 240) && (degree <= 300))) assertTrue(value <
* 0.8); // the fan is above at 90,210,330
*
* if (((degree >= 75) && (degree <= 105)) || ((degree >= 195) && (degree <=
* 225)) || ((degree >= 315) && (degree <= 345))) assertTrue(value > 0.5); //
* fan is close by //intentionally 0.5 is below 0.8 for a robust test
*
* // Testing for directionality // There can be unclear directionality when the
* blades are right above
*
* double bladeAboveInY = density.getDirectionalDensity(0.001, 2).low[1]; double
* bladeBelowInY = density.getDirectionalDensity(0.001, 2).high[1]; double
* bladesToTheLeft = density.getDirectionalDensity(0.001, 2).high[0]; double
* bladesToTheRight = density.getDirectionalDensity(0.001, 2).low[0];
*
*
* assertEquals(value, bladeAboveInY + bladeBelowInY + bladesToTheLeft +
* bladesToTheRight, 1E-6);
*
* // the tests below have a freedom of 10% of the total value if (((degree >=
* 75) && (degree <= 86)) || ((degree >= 195) && (degree <= 206)) || ((degree >=
* 315) && (degree <= 326))) { assertTrue(bladeAboveInY + 0.1 * value >
* bladeBelowInY); assertTrue(bladeAboveInY + 0.1 * value > bladesToTheRight); }
*
* if (((degree >= 94) && (degree <= 105)) || ((degree >= 214) && (degree <=
* 225)) || ((degree >= 334) && (degree <= 345))) { assertTrue(bladeBelowInY +
* 0.1 * value > bladeAboveInY); assertTrue(bladeBelowInY + 0.1 * value >
* bladesToTheRight); }
*
* if (((degree >= 60) && (degree <= 75)) || ((degree >= 180) && (degree <=
* 195)) || ((degree >= 300) && (degree <= 315))) { assertTrue(bladeAboveInY +
* 0.1 * value > bladesToTheLeft); assertTrue(bladeAboveInY + 0.1 * value >
* bladesToTheRight); }
*
* if (((degree >= 105) && (degree <= 120)) || ((degree >= 225) && (degree <=
* 240)) || (degree >= 345)) { assertTrue(bladeBelowInY + 0.1 * value >
* bladesToTheLeft); assertTrue(bladeBelowInY + 0.1 * value > bladesToTheRight);
* }
*
* // fans are farthest to the left at 30,150 and 270 if (((degree >= 15) &&
* (degree <= 45)) || ((degree >= 135) && (degree <= 165)) || ((degree >= 255)
* && (degree <= 285))) { assertTrue(bladesToTheLeft + 0.1 * value >
* bladeAboveInY + bladeBelowInY + bladesToTheRight); assertTrue(bladeAboveInY +
* bladeBelowInY + 0.1 * value > bladesToTheRight); }
*
* }
*
* }
*
* @Test public void movingNeighbors() { int newDimensions = 2; randomSeed =
* 123;
*
* RandomCutForest newForest = RandomCutForest.builder() .numberOfTrees(100)
* .sampleSize(256) .dimensions(newDimensions) .randomSeed(randomSeed)
* .windowSize(800) .centerOfMassEnabled(true)
* .storeSequenceIndexesEnabled(true) .build();
*
* double[][] data = generateFan(1000, 3);
*
* double[] queryPoint = new double[]{0.7, 0}; for (int degree = 0; degree <
* 360; degree += 2) { for (int j = 0; j < data.length; j++) {
* newForest.update(rotateClockWise(data[j], 2 * PI * degree / 360)); }
* List<Neighbor> ans=newForest.getNearNeighborsInSample(queryPoint,1);
* List<Neighbor>
* closeNeighBors=newForest.getNearNeighborsInSample(queryPoint,0.1); Neighbor
* best = null; if (ans!=null) { best = ans.get(0); for (int j = 1; j <
* ans.size(); j++) { assert (ans.get(j).distance >= best.distance); } }
*
* // fan is away at 30, 150 and 270 if (((degree>15) && (degree<45))|| ((degree
* >= 135) && (degree <= 165)) || ((degree >= 255) && (degree <= 285))) {
* assertTrue(closeNeighBors.size()==0); // no close neighbor
* assertTrue(best.distance>0.3); }
*
* // fan is overhead at 90, 210 and 330 if (((degree>75) && (degree<105))||
* ((degree >= 195) && (degree <= 225)) || ((degree >= 315) && (degree <= 345)))
* { assertTrue(closeNeighBors.size()>0);
* assertEquals(closeNeighBors.get(0).distance,best.distance,1E-10); }
*
* }
*
* }
*/

@Test
public void movingDensity() {
int newDimensions = 2;
randomSeed = 123;

RandomCutForest newForest = RandomCutForest.builder().dimensions(newDimensions).randomSeed(randomSeed)
.timeDecay(1.0 / 800).centerOfMassEnabled(true).storeSequenceIndexesEnabled(true).build();

double[][] data = generateFan(1000, 3);

double[] queryPoint = new double[] { 0.7, 0 };
for (int degree = 0; degree < 360; degree += 2) {
for (int j = 0; j < data.length; j++) {
newForest.update(rotateClockWise(data[j], 2 * PI * degree / 360));
}
DensityOutput density = newForest.getSimpleDensity(queryPoint);
double value = density.getDensity(0.001, 2);
if ((degree <= 60) || ((degree >= 120) && (degree <= 180)) || ((degree >= 240) && (degree <= 300)))
assertTrue(value < 0.8); // the fan is above at 90,210,330

if (((degree >= 75) && (degree <= 105)) || ((degree >= 195) && (degree <= 225))
|| ((degree >= 315) && (degree <= 345)))
assertTrue(value > 0.5);
// fan is close by
// intentionally 0.5 is below 0.8 for a robust test

// Testing for directionality
// There can be unclear directionality when the
// blades are right above

double bladeAboveInY = density.getDirectionalDensity(0.001, 2).low[1];
double bladeBelowInY = density.getDirectionalDensity(0.001, 2).high[1];
double bladesToTheLeft = density.getDirectionalDensity(0.001, 2).high[0];
double bladesToTheRight = density.getDirectionalDensity(0.001, 2).low[0];

assertEquals(value, bladeAboveInY + bladeBelowInY + bladesToTheLeft + bladesToTheRight, 1E-6);

// the tests below have a freedom of 10% of the total value
if (((degree >= 75) && (degree <= 85)) || ((degree >= 195) && (degree <= 205))
|| ((degree >= 315) && (degree <= 325))) {
assertTrue(bladeAboveInY + 0.1 * value > bladeBelowInY);
assertTrue(bladeAboveInY + 0.1 * value > bladesToTheRight);
}

if (((degree >= 95) && (degree <= 105)) || ((degree >= 215) && (degree <= 225))
|| ((degree >= 335) && (degree <= 345))) {
assertTrue(bladeBelowInY + 0.1 * value > bladeAboveInY);
assertTrue(bladeBelowInY + 0.1 * value > bladesToTheRight);
}

if (((degree >= 60) && (degree <= 75)) || ((degree >= 180) && (degree <= 195))
|| ((degree >= 300) && (degree <= 315))) {
assertTrue(bladeAboveInY + 0.1 * value > bladesToTheLeft);
assertTrue(bladeAboveInY + 0.1 * value > bladesToTheRight);
}

if (((degree >= 105) && (degree <= 120)) || ((degree >= 225) && (degree <= 240)) || (degree >= 345)) {
assertTrue(bladeBelowInY + 0.1 * value > bladesToTheLeft);
assertTrue(bladeBelowInY + 0.1 * value > bladesToTheRight);
}

// fans are farthest to the left at 30,150 and 270
if (((degree >= 15) && (degree <= 45)) || ((degree >= 135) && (degree <= 165))
|| ((degree >= 255) && (degree <= 285))) {
assertTrue(bladesToTheLeft + 0.1 * value > bladeAboveInY + bladeBelowInY + bladesToTheRight);
assertTrue(bladeAboveInY + bladeBelowInY + 0.1 * value > bladesToTheRight);
}

}

}

@Test
public void movingNeighbors() {
int newDimensions = 2;
randomSeed = 123;

RandomCutForest newForest = RandomCutForest.builder().dimensions(newDimensions).randomSeed(randomSeed)
.timeDecay(1.0 / 800).centerOfMassEnabled(true).storeSequenceIndexesEnabled(true).build();

double[][] data = generateFan(1000, 3);

double[] queryPoint = new double[] { 0.7, 0 };
for (int degree = 0; degree < 360; degree += 2) {
for (int j = 0; j < data.length; j++) {
newForest.update(rotateClockWise(data[j], 2 * PI * degree / 360));
}
List<Neighbor> ans = newForest.getNearNeighborsInSample(queryPoint, 1);
List<Neighbor> closeNeighBors = newForest.getNearNeighborsInSample(queryPoint, 0.1);
Neighbor best = null;
if (ans != null) {
best = ans.get(0);
for (int j = 1; j < ans.size(); j++) {
assert (ans.get(j).distance >= best.distance);
}
}

// fan is away at 30, 150 and 270
if (((degree > 15) && (degree < 45)) || ((degree >= 135) && (degree <= 165))
|| ((degree >= 255) && (degree <= 285))) {
assertTrue(closeNeighBors.size() == 0); // no close neighbor
assertTrue(best.distance > 0.3);
}

// fan is overhead at 90, 210 and 330
if (((degree > 75) && (degree < 105)) || ((degree >= 195) && (degree <= 225))
|| ((degree >= 315) && (degree <= 345))) {
assertTrue(closeNeighBors.size() > 0);
assertEquals(closeNeighBors.get(0).distance, best.distance, 1E-10);
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@

package com.amazon.randomcutforest.examples;

import java.util.Map;
import java.util.TreeMap;

import com.amazon.randomcutforest.examples.dynamicinference.DynamicDensity;
import com.amazon.randomcutforest.examples.dynamicinference.DynamicNearNeighbor;
import com.amazon.randomcutforest.examples.serialization.JsonExample;
import com.amazon.randomcutforest.examples.serialization.ProtostuffExample;

import java.util.Map;
import java.util.TreeMap;

public class Main {

public static final String ARCHIVE_NAME = "randomcutforest-examples-1.0.jar";
Expand All @@ -37,6 +39,8 @@ public Main() {
maxCommandLength = 0;
add(new JsonExample());
add(new ProtostuffExample());
add(new DynamicDensity());
add(new DynamicNearNeighbor());
}

private void add(Example example) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.randomcutforest.examples.dynamicinference;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.examples.Example;
import com.amazon.randomcutforest.returntypes.DensityOutput;

import java.io.BufferedWriter;
import java.io.FileWriter;

import static com.amazon.randomcutforest.testutils.ExampleDataSets.generate;
import static com.amazon.randomcutforest.testutils.ExampleDataSets.rotateClockWise;
import static java.lang.Math.PI;

public class DynamicDensity implements Example {

public static void main(String[] args) throws Exception {
new DynamicDensity().run();
}

@Override
public String command() {
return "dynamic_sampling";
}

@Override
public String description() {
return "shows two potential use of dynamic density computations; estimating density as well " +
"as its directional components";
}

/**
* plot the dynamic_density_example using any tool in gnuplot one can plot the
* directions to higher density via do for [i=0:358:2] {plot
* "dynamic_density_example" index (i+1) u 1:2:3:4 w vectors t ""} or the raw
* density at the points via do for [i=0:358:2] {plot "dynamic_density_example"
* index i w p pt 7 palette t ""}
*
* @throws Exception
*/
@Override
public void run() throws Exception {
int newDimensions = 2;
long randomSeed = 123;

RandomCutForest newForest = RandomCutForest.builder().numberOfTrees(100).sampleSize(256)
.dimensions(newDimensions).randomSeed(randomSeed).timeDecay(1.0 / 800).centerOfMassEnabled(true)
.build();
String name = "dynamic_density_example";
BufferedWriter file = new BufferedWriter(new FileWriter(name));
double[][] data = generate(1000);
double[] queryPoint;
for (int degree = 0; degree < 360; degree += 2) {
for (double[] datum : data) {
newForest.update(rotateClockWise(datum, -2 * PI * degree / 360));
}
for (double[] datum : data) {
queryPoint = rotateClockWise(datum, -2 * PI * degree / 360);
DensityOutput density = newForest.getSimpleDensity(queryPoint);
double value = density.getDensity(0.001, 2);
file.append(queryPoint[0] + " " + queryPoint[1] + " " + value + "\n");
}
file.append("\n");
file.append("\n");

for (double x = -0.95; x < 1; x += 0.1) {
for (double y = -0.95; y < 1; y += 0.1) {
DensityOutput density = newForest.getSimpleDensity(new double[] { x, y });
double aboveInY = density.getDirectionalDensity(0.001, 2).low[1];
double belowInY = density.getDirectionalDensity(0.001, 2).high[1];
double toTheLeft = density.getDirectionalDensity(0.001, 2).high[0];
double toTheRight = density.getDirectionalDensity(0.001, 2).low[0];
double len = Math.sqrt(aboveInY * aboveInY + belowInY * belowInY + toTheLeft * toTheLeft
+ toTheRight * toTheRight);
file.append(x + " " + y + " " + ((toTheRight - toTheLeft) * 0.05 / len) + " "
+ ((aboveInY - belowInY) * 0.05 / len) + "\n");
}
}
file.append("\n");
file.append("\n");
}
file.close();
}
}
Loading

0 comments on commit 1942b88

Please sign in to comment.