diff --git a/src/splat/sort-worker.ts b/src/splat/sort-worker.ts index b30c6ac..677c28d 100644 --- a/src/splat/sort-worker.ts +++ b/src/splat/sort-worker.ts @@ -9,9 +9,10 @@ function SortWorker() { const epsilon = 0.0001; let data: Float32Array; - let stride: number; + let centers: Float32Array; let cameraPosition: Vec3; let cameraDirection: Vec3; + let isWebGPU: boolean; const lastCameraPosition = { x: 0, y: 0, z: 0 }; const lastCameraDirection = { x: 0, y: 0, z: 0 }; @@ -21,7 +22,7 @@ function SortWorker() { let target: Float32Array; const update = () => { - if (!data || !stride || !cameraPosition || !cameraDirection) return; + if (!centers || !data || !cameraPosition || !cameraDirection) return; const px = cameraPosition.x; const py = cameraPosition.y; @@ -47,24 +48,23 @@ function SortWorker() { lastCameraDirection.y = dy; lastCameraDirection.z = dz; - const numVertices = data.length / stride; + const numVertices = centers.length / 3; // create distance buffer if (orderBuffer?.length !== numVertices) { orderBuffer = new BigUint64Array(numVertices); orderBuffer32 = new Uint32Array(orderBuffer.buffer); - target = new Float32Array(numVertices * stride); + target = new Float32Array(numVertices); } - const strideVertices = numVertices * stride; - // calc min/max - let minDist = (data[0] - px) * dx + (data[1] - py) * dy + (data[2] - pz) * dz; + let minDist = (centers[0] - px) * dx + (centers[1] - py) * dy + (centers[2] - pz) * dz; let maxDist = minDist; - for (let i = stride; i < strideVertices; i += stride) { - const d = (data[i + 0] - px) * dx + - (data[i + 1] - py) * dy + - (data[i + 2] - pz) * dz; + for (let i = 1; i < numVertices; i++) { + const istride = i * 3; + const d = (centers[istride + 0] - px) * dx + + (centers[istride + 1] - py) * dy + + (centers[istride + 2] - pz) * dz; minDist = Math.min(minDist, d); maxDist = Math.max(maxDist, d); } @@ -72,23 +72,28 @@ function SortWorker() { // generate per vertex distance to camera const range = maxDist - minDist; for (let i = 0; i < numVertices; ++i) { - const istride = i * stride; - const d = (data[istride + 0] - px) * dx + - (data[istride + 1] - py) * dy + - (data[istride + 2] - pz) * dz; + const istride = i * 3; + const d = (centers[istride + 0] - px) * dx + + (centers[istride + 1] - py) * dy + + (centers[istride + 2] - pz) * dz; + orderBuffer32[i * 2 + 0] = i; orderBuffer32[i * 2 + 1] = Math.floor((d - minDist) / range * (2 ** 32)); - orderBuffer32[i * 2] = i; } // sort indices orderBuffer.sort(); // order the splat data - for (let i = 0; i < numVertices; ++i) { - const ti = i * stride; - const si = orderBuffer32[i * 2] * stride; - for (let j = 0; j < stride; ++j) { - target[ti + j] = data[si + j]; + if (isWebGPU) { + const target32 = new Uint32Array(target.buffer); + for (let i = 0; i < numVertices; ++i) { + const index = orderBuffer32[i * 2]; + target32[i] = index; + } + } else { + for (let i = 0; i < numVertices; ++i) { + const index = orderBuffer32[i * 2]; + target[i] = index + 0.2; } } @@ -109,8 +114,11 @@ function SortWorker() { if (message.data.data) { data = new Float32Array(message.data.data); } - if (message.data.stride) { - stride = message.data.stride; + if (message.data.centers) { + centers = new Float32Array(message.data.centers); + } + if (message.data.isWebGPU) { + isWebGPU = message.data.isWebGPU; } if (message.data.cameraPosition) cameraPosition = message.data.cameraPosition; if (message.data.cameraDirection) cameraDirection = message.data.cameraDirection; diff --git a/src/splat/splat.ts b/src/splat/splat.ts index 7e2771e..3b2402a 100644 --- a/src/splat/splat.ts +++ b/src/splat/splat.ts @@ -12,7 +12,6 @@ import { Mesh, Vec3, createBox, - SEMANTIC_ATTR11, SEMANTIC_ATTR13, TYPE_FLOAT32, VertexFormat, @@ -294,7 +293,6 @@ class Splat { return texture; } - create(splatData: SplatData, options: any) { const x = splatData.getProp('x'); const y = splatData.getProp('y'); @@ -304,34 +302,18 @@ class Splat { return; } - const stride = 4; - const textureSize = this.evalTextureSize(splatData.numSplats); const colorTexture = this.createColorTexture(splatData, textureSize); const scaleTexture = this.createScaleTexture(splatData, textureSize, this.getTextureFormat(false)); const rotationTexture = this.createRotationTexture(splatData, textureSize, this.getTextureFormat(false)); const centerTexture = this.createCenterTexture(splatData, textureSize, this.getTextureFormat(false)); - // position.xyz, rotation.xyz - const floatData = new Float32Array(splatData.numSplats * stride); - const uint32Data = new Uint32Array(floatData.buffer); - - const isWebGPU = this.device.isWebGPU; - + // centers - constant buffer that is sent to the worker + const centers = new Float32Array(splatData.numSplats * 3); for (let i = 0; i < splatData.numSplats; ++i) { - const j = i; - - // positions - floatData[i * stride + 0] = x[j]; - floatData[i * stride + 1] = y[j]; - floatData[i * stride + 2] = z[j]; - - // index - if (isWebGPU) { - uint32Data[i * stride + 3] = i; - } else { - floatData[i * stride + 3] = i + 0.2; - } + centers[i * 3 + 0] = x[i]; + centers[i * 3 + 1] = y[i]; + centers[i * 3 + 2] = z[i]; } this.material.setParameter('splatColor', colorTexture); @@ -341,10 +323,12 @@ class Splat { this.material.setParameter('tex_params', new Float32Array([textureSize.x, textureSize.y, 1 / textureSize.x, 1 / textureSize.y])); // create instance data + const isWebGPU = this.device.isWebGPU; const vertexFormat = new VertexFormat(this.device, [ - { semantic: SEMANTIC_ATTR11, components: 3, type: TYPE_FLOAT32 }, { semantic: SEMANTIC_ATTR13, components: 1, type: isWebGPU ? TYPE_UINT32 : TYPE_FLOAT32 } ]); + + const floatData = new Float32Array(splatData.numSplats); const vertexBuffer = new VertexBuffer(this.device, vertexFormat, splatData.numSplats, BUFFER_DYNAMIC, floatData.buffer); this.meshInstance = new MeshInstance(this.quadMesh, this.material); @@ -362,7 +346,7 @@ class Splat { sortWorker.onmessage = (message: any) => { const data = message.data.data; - // copy source data + // copy output data to VB floatData.set(new Float32Array(data)); // send the memory buffer back to worker @@ -379,10 +363,12 @@ class Splat { // send the initial buffer to worker const buf = vertexBuffer.storage.slice(0); + const centerBuf = centers.buffer.slice(0); sortWorker.postMessage({ data: buf, - stride: stride - }, [buf]); + centers: centerBuf, + isWebGPU: isWebGPU + }, [buf, centerBuf]); const viewport = [this.device.width, this.device.height];