Skip to content

Commit

Permalink
Splat worker generates indicies by writing to VB (#258)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Valigursky <[email protected]>
  • Loading branch information
mvaligursky and Martin Valigursky committed Oct 20, 2023
1 parent 6748515 commit d69403f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 50 deletions.
54 changes: 31 additions & 23 deletions src/splat/sort-worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ function SortWorker() {
const epsilon = 0.0001;

let data: Float32Array;
let stride: number;
let centers: Float32Array;
let cameraPosition: Vec3;
let cameraDirection: Vec3;
let isWebGPU: boolean;

const lastCameraPosition = { x: 0, y: 0, z: 0 };
const lastCameraDirection = { x: 0, y: 0, z: 0 };
Expand All @@ -21,7 +22,7 @@ function SortWorker() {
let target: Float32Array;

const update = () => {
if (!data || !stride || !cameraPosition || !cameraDirection) return;
if (!centers || !data || !cameraPosition || !cameraDirection) return;

const px = cameraPosition.x;
const py = cameraPosition.y;
Expand All @@ -47,48 +48,52 @@ function SortWorker() {
lastCameraDirection.y = dy;
lastCameraDirection.z = dz;

const numVertices = data.length / stride;
const numVertices = centers.length / 3;

// create distance buffer
if (orderBuffer?.length !== numVertices) {
orderBuffer = new BigUint64Array(numVertices);
orderBuffer32 = new Uint32Array(orderBuffer.buffer);
target = new Float32Array(numVertices * stride);
target = new Float32Array(numVertices);
}

const strideVertices = numVertices * stride;

// calc min/max
let minDist = (data[0] - px) * dx + (data[1] - py) * dy + (data[2] - pz) * dz;
let minDist = (centers[0] - px) * dx + (centers[1] - py) * dy + (centers[2] - pz) * dz;
let maxDist = minDist;
for (let i = stride; i < strideVertices; i += stride) {
const d = (data[i + 0] - px) * dx +
(data[i + 1] - py) * dy +
(data[i + 2] - pz) * dz;
for (let i = 1; i < numVertices; i++) {
const istride = i * 3;
const d = (centers[istride + 0] - px) * dx +
(centers[istride + 1] - py) * dy +
(centers[istride + 2] - pz) * dz;
minDist = Math.min(minDist, d);
maxDist = Math.max(maxDist, d);
}

// generate per vertex distance to camera
const range = maxDist - minDist;
for (let i = 0; i < numVertices; ++i) {
const istride = i * stride;
const d = (data[istride + 0] - px) * dx +
(data[istride + 1] - py) * dy +
(data[istride + 2] - pz) * dz;
const istride = i * 3;
const d = (centers[istride + 0] - px) * dx +
(centers[istride + 1] - py) * dy +
(centers[istride + 2] - pz) * dz;
orderBuffer32[i * 2 + 0] = i;
orderBuffer32[i * 2 + 1] = Math.floor((d - minDist) / range * (2 ** 32));
orderBuffer32[i * 2] = i;
}

// sort indices
orderBuffer.sort();

// order the splat data
for (let i = 0; i < numVertices; ++i) {
const ti = i * stride;
const si = orderBuffer32[i * 2] * stride;
for (let j = 0; j < stride; ++j) {
target[ti + j] = data[si + j];
if (isWebGPU) {
const target32 = new Uint32Array(target.buffer);
for (let i = 0; i < numVertices; ++i) {
const index = orderBuffer32[i * 2];
target32[i] = index;
}
} else {
for (let i = 0; i < numVertices; ++i) {
const index = orderBuffer32[i * 2];
target[i] = index + 0.2;
}
}

Expand All @@ -109,8 +114,11 @@ function SortWorker() {
if (message.data.data) {
data = new Float32Array(message.data.data);
}
if (message.data.stride) {
stride = message.data.stride;
if (message.data.centers) {
centers = new Float32Array(message.data.centers);
}
if (message.data.isWebGPU) {
isWebGPU = message.data.isWebGPU;
}
if (message.data.cameraPosition) cameraPosition = message.data.cameraPosition;
if (message.data.cameraDirection) cameraDirection = message.data.cameraDirection;
Expand Down
40 changes: 13 additions & 27 deletions src/splat/splat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import {
Mesh,
Vec3,
createBox,
SEMANTIC_ATTR11,
SEMANTIC_ATTR13,
TYPE_FLOAT32,
VertexFormat,
Expand Down Expand Up @@ -294,7 +293,6 @@ class Splat {
return texture;
}


create(splatData: SplatData, options: any) {
const x = splatData.getProp('x');
const y = splatData.getProp('y');
Expand All @@ -304,34 +302,18 @@ class Splat {
return;
}

const stride = 4;

const textureSize = this.evalTextureSize(splatData.numSplats);
const colorTexture = this.createColorTexture(splatData, textureSize);
const scaleTexture = this.createScaleTexture(splatData, textureSize, this.getTextureFormat(false));
const rotationTexture = this.createRotationTexture(splatData, textureSize, this.getTextureFormat(false));
const centerTexture = this.createCenterTexture(splatData, textureSize, this.getTextureFormat(false));

// position.xyz, rotation.xyz
const floatData = new Float32Array(splatData.numSplats * stride);
const uint32Data = new Uint32Array(floatData.buffer);

const isWebGPU = this.device.isWebGPU;

// centers - constant buffer that is sent to the worker
const centers = new Float32Array(splatData.numSplats * 3);
for (let i = 0; i < splatData.numSplats; ++i) {
const j = i;

// positions
floatData[i * stride + 0] = x[j];
floatData[i * stride + 1] = y[j];
floatData[i * stride + 2] = z[j];

// index
if (isWebGPU) {
uint32Data[i * stride + 3] = i;
} else {
floatData[i * stride + 3] = i + 0.2;
}
centers[i * 3 + 0] = x[i];
centers[i * 3 + 1] = y[i];
centers[i * 3 + 2] = z[i];
}

this.material.setParameter('splatColor', colorTexture);
Expand All @@ -341,10 +323,12 @@ class Splat {
this.material.setParameter('tex_params', new Float32Array([textureSize.x, textureSize.y, 1 / textureSize.x, 1 / textureSize.y]));

// create instance data
const isWebGPU = this.device.isWebGPU;
const vertexFormat = new VertexFormat(this.device, [
{ semantic: SEMANTIC_ATTR11, components: 3, type: TYPE_FLOAT32 },
{ semantic: SEMANTIC_ATTR13, components: 1, type: isWebGPU ? TYPE_UINT32 : TYPE_FLOAT32 }
]);

const floatData = new Float32Array(splatData.numSplats);
const vertexBuffer = new VertexBuffer(this.device, vertexFormat, splatData.numSplats, BUFFER_DYNAMIC, floatData.buffer);

this.meshInstance = new MeshInstance(this.quadMesh, this.material);
Expand All @@ -362,7 +346,7 @@ class Splat {
sortWorker.onmessage = (message: any) => {
const data = message.data.data;

// copy source data
// copy output data to VB
floatData.set(new Float32Array(data));

// send the memory buffer back to worker
Expand All @@ -379,10 +363,12 @@ class Splat {

// send the initial buffer to worker
const buf = vertexBuffer.storage.slice(0);
const centerBuf = centers.buffer.slice(0);
sortWorker.postMessage({
data: buf,
stride: stride
}, [buf]);
centers: centerBuf,
isWebGPU: isWebGPU
}, [buf, centerBuf]);

const viewport = [this.device.width, this.device.height];

Expand Down

0 comments on commit d69403f

Please sign in to comment.