Skip to content

Commit

Permalink
Store splat instances, support adding splat property (#264)
Browse files Browse the repository at this point in the history
* store splat instances

* addProp

* debug shader

* whitespace
  • Loading branch information
slimbuck committed Oct 26, 2023
1 parent f61cb7d commit 44cba90
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 70 deletions.
5 changes: 5 additions & 0 deletions src/splat/sort-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ class SortManager {
};
}

destroy() {
this.worker.terminate();
this.worker = null;
}

sort(vertexBuffer: VertexBuffer, centers: Float32Array, intIndices: boolean, updatedCallback?: () => void) {
this.vertexBuffer = vertexBuffer;
this.updatedCallback = updatedCallback;
Expand Down
10 changes: 10 additions & 0 deletions src/splat/splat-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@ class SplatData {
return this.vertexElement.properties.find((property: any) => property.name === name && property.storage)?.storage;
}

// add a new property
addProp(name: string, storage: Float32Array) {
this.vertexElement.properties.push({
type: 'float',
name,
storage,
byteSize: 4
});
}

// calculate scene aabb taking into account splat size
calcAabb(result: BoundingBox) {
const x = this.getProp('x');
Expand Down
65 changes: 31 additions & 34 deletions src/splat/splat-material.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,33 +159,40 @@ const splatVS = `
return;
}
vec3 splat_cova;
vec3 splat_covb;
vec3 scale = getScale();
vec3 rotation = getRotation();
computeCov3d(mat3(matrix_model) * quatToMat3(rotation), scale, splat_cova, splat_covb);
mat3 Vrk = mat3(
splat_cova.x, splat_cova.y, splat_cova.z,
splat_cova.y, splat_covb.x, splat_covb.y,
splat_cova.z, splat_covb.y, splat_covb.z
);
color = getColor();
float focal = viewport.x * matrix_projection[0][0];
#ifdef DEBUG_RENDER
vec3 local = quatToMat3(rotation) * (vertex_position * scale * 2.0) + center;
gl_Position = matrix_viewProjection * matrix_model * vec4(local, 1.0);
#else
vec3 splat_cova;
vec3 splat_covb;
computeCov3d(mat3(matrix_model) * quatToMat3(rotation), scale, splat_cova, splat_covb);
mat3 J = mat3(
focal / splat_cam.z, 0., -(focal * splat_cam.x) / (splat_cam.z * splat_cam.z),
0., focal / splat_cam.z, -(focal * splat_cam.y) / (splat_cam.z * splat_cam.z),
0., 0., 0.
);
mat3 Vrk = mat3(
splat_cova.x, splat_cova.y, splat_cova.z,
splat_cova.y, splat_covb.x, splat_covb.y,
splat_cova.z, splat_covb.y, splat_covb.z
);
float focal = viewport.x * matrix_projection[0][0];
mat3 J = mat3(
focal / splat_cam.z, 0., -(focal * splat_cam.x) / (splat_cam.z * splat_cam.z),
0., focal / splat_cam.z, -(focal * splat_cam.y) / (splat_cam.z * splat_cam.z),
0., 0., 0.
);
mat3 W = transpose(mat3(matrix_view));
mat3 T = W * J;
mat3 cov = transpose(T) * Vrk * T;
mat3 W = transpose(mat3(matrix_view));
mat3 T = W * J;
mat3 cov = transpose(T) * Vrk * T;
float diagonal1 = cov[0][0] + 0.3;
float offDiagonal = cov[0][1];
float diagonal2 = cov[1][1] + 0.3;
float diagonal1 = cov[0][0] + 0.3;
float offDiagonal = cov[0][1];
float diagonal2 = cov[1][1] + 0.3;
float mid = 0.5 * (diagonal1 + diagonal2);
float radius = length(vec2((diagonal1 - diagonal2) / 2.0, offDiagonal));
Expand All @@ -195,21 +202,11 @@ const splatVS = `
vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);
gl_Position = splat_proj +
vec4((vertex_position.x * v1 + vertex_position.y * v2) / viewport * 2.0,
0.0, 0.0) * splat_proj.w;
texCoord = vertex_position.xy * 2.0;
color = getColor();
#ifdef DEBUG_RENDER
vec3 local = quatToMat3(rotation) * (vertex_position * scale * 2.0) + center;
gl_Position = matrix_viewProjection * matrix_model * vec4(local, 1.0);
color = getColor();
gl_Position = splat_proj +
vec4((vertex_position.x * v1 + vertex_position.y * v2) / viewport * 2.0,
0.0, 0.0) * splat_proj.w;
texCoord = vertex_position.xy * 2.0;
#endif
}
`;
Expand Down
59 changes: 39 additions & 20 deletions src/splat/splat-resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,50 @@ const mat = new Mat4();
const pos = new Vec3();
const dir = new Vec3();

const debugRender = false;
const debugRenderBounds = false;

class SplatResource extends ContainerResource {
device: GraphicsDevice;
splatData: SplatData;
sortManager: SortManager;

focalPoint = new Vec3();
entity: Entity;
customAabb = new BoundingBox();

renders: RenderComponent[] = [];
meshes: Mesh[] = [];
materials: Material[] = [];
textures: Texture[] = [];

handle: any;
// per-instance data (currently only a single instance is supported)
instances: {
splat: Splat,
sortManager: SortManager,
entity: Entity,
callbackHandle: any
}[] = [];

constructor(device: GraphicsDevice, splatData: SplatData) {
super();

this.device = device;
this.splatData = splatData;

// calculate splat focal point
this.splatData.calcFocalPoint(this.focalPoint);

// calculate custom aabb
this.splatData.calcAabb(this.customAabb);
}

destroy() {
this.handle.off();
this.instances.forEach((instance) => {
instance.splat.destroy();
instance.sortManager.destroy();
instance.entity.destroy();
instance.callbackHandle.off();
});

// TODO
this.splatData = null;
}

instantiateModelEntity(/* options: any */): Entity {
Expand All @@ -55,7 +71,7 @@ class SplatResource extends ContainerResource {

instantiateRenderEntity(options: any): Entity {
const splatData = this.splatData;
const splat = new Splat(this.device, splatData.numSplats, false);
const splat = new Splat(this.device, splatData.numSplats, options?.debugRender ?? debugRender);

splat.updateColorData(splatData.getProp('f_dc_0'), splatData.getProp('f_dc_1'), splatData.getProp('f_dc_2'), splatData.getProp('opacity'));
splat.updateScaleData(splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2'));
Expand All @@ -69,14 +85,8 @@ class SplatResource extends ContainerResource {
castShadows: false // shadows not supported
});

this.entity = result;

// set custom aabb
const customAabb = new BoundingBox();
this.splatData.calcAabb(customAabb);
result.render.customAabb = customAabb;

this.splatData.calcFocalPoint(this.focalPoint);
result.render.customAabb = this.customAabb;

// centers - constant buffer that is sent to the worker
const x = this.splatData.getProp('x');
Expand All @@ -91,8 +101,8 @@ class SplatResource extends ContainerResource {
}

// initialize sort
this.sortManager = new SortManager();
this.sortManager.sort(
const sortManager = new SortManager();
sortManager.sort(
splat.meshInstance.instancingData.vertexBuffer,
centers,
this.device.isWebGPU,
Expand All @@ -101,17 +111,17 @@ class SplatResource extends ContainerResource {

const viewport = [0, 0];

this.handle = options.app.on('prerender', () => {
const callbackHandle = options.app.on('prerender', () => {
const cameraMat = options.camera.getWorldTransform();
cameraMat.getTranslation(pos);
cameraMat.getZ(dir);

const modelMat = this.entity.getWorldTransform();
const modelMat = result.getWorldTransform();
const invModelMat = mat.invert(modelMat);
invModelMat.transformPoint(pos, pos);
invModelMat.transformVector(dir, dir);

this.sortManager.setCamera(pos, dir);
sortManager.setCamera(pos, dir);

viewport[0] = this.device.width;
viewport[1] = this.device.height;
Expand All @@ -123,11 +133,20 @@ class SplatResource extends ContainerResource {
}
});

// store instance
this.instances.push({
splat,
sortManager,
entity: result,
callbackHandle
});

return result;
}

getFocalPoint(): Vec3 {
return this.entity.getWorldTransform().transformPoint(this.focalPoint);
const instance = this.instances[0];
return instance?.entity.getWorldTransform().transformPoint(this.focalPoint);
}
}

Expand Down
33 changes: 18 additions & 15 deletions src/splat/splat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ const float2Half = (value: number) => {
return bits;
};


const evalTextureSize = (count: number) : Vec2 => {
const width = Math.ceil(Math.sqrt(count));
const height = Math.ceil(count / width);
Expand Down Expand Up @@ -123,7 +122,6 @@ const getTextureFormat = (device: GraphicsDevice, preferHighPrecision: boolean)
};

class Splat {
device: GraphicsDevice;
numSplats: number;
material: Material;
mesh: Mesh;
Expand All @@ -136,32 +134,31 @@ class Splat {
centerTexture: Texture;

constructor(device: GraphicsDevice, numSplats: number, debugRender = false) {
this.device = device;
this.numSplats = numSplats;

// material
this.material = createSplatMaterial(device, debugRender);

// mesh
if (debugRender) {
this.mesh = createBox(this.device, {
this.mesh = createBox(device, {
halfExtents: new Vec3(1.0, 1.0, 1.0)
});
} else {
this.mesh = new Mesh(this.device);
this.mesh = new Mesh(device);
this.mesh.setPositions(new Float32Array([
-1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1
]), 2);
this.mesh.update();
}

// mesh instance
const vertexFormat = new VertexFormat(this.device, [
{ semantic: SEMANTIC_ATTR13, components: 1, type: this.device.isWebGPU ? TYPE_UINT32 : TYPE_FLOAT32 }
const vertexFormat = new VertexFormat(device, [
{ semantic: SEMANTIC_ATTR13, components: 1, type: device.isWebGPU ? TYPE_UINT32 : TYPE_FLOAT32 }
]);

const vertexBuffer = new VertexBuffer(
this.device,
device,
vertexFormat,
numSplats,
BUFFER_DYNAMIC,
Expand All @@ -175,10 +172,10 @@ class Splat {
const size = evalTextureSize(numSplats);

this.format = getTextureFormat(device, false);
this.colorTexture = createTexture(this.device, 'splatColor', PIXELFORMAT_RGBA8, size);
this.scaleTexture = createTexture(this.device, 'splatScale', this.format.format, size);
this.rotationTexture = createTexture(this.device, 'splatRotation', this.format.format, size);
this.centerTexture = createTexture(this.device, 'splatCenter', this.format.format, size);
this.colorTexture = createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size);
this.scaleTexture = createTexture(device, 'splatScale', this.format.format, size);
this.rotationTexture = createTexture(device, 'splatRotation', this.format.format, size);
this.centerTexture = createTexture(device, 'splatCenter', this.format.format, size);

this.material.setParameter('splatColor', this.colorTexture);
this.material.setParameter('splatScale', this.scaleTexture);
Expand All @@ -187,6 +184,15 @@ class Splat {
this.material.setParameter('tex_params', new Float32Array([size.x, size.y, 1 / size.x, 1 / size.y]));
}

destroy() {
this.colorTexture.destroy();
this.scaleTexture.destroy();
this.rotationTexture.destroy();
this.centerTexture.destroy();
this.material.destroy();
this.mesh.destroy();
}

updateColorData(f_dc_0: TypedArray, f_dc_1: TypedArray, f_dc_2: TypedArray, opacity: TypedArray) {
const SH_C0 = 0.28209479177387814;
const texture = this.colorTexture;
Expand Down Expand Up @@ -218,7 +224,6 @@ class Splat {
}

updateScaleData(scale_0: TypedArray, scale_1: TypedArray, scale_2: TypedArray) {
// texture format based vars
const { numComponents, isHalf } = this.format;
const texture = this.scaleTexture;
const data = texture.lock();
Expand All @@ -244,7 +249,6 @@ class Splat {
}

updateRotationData(rot_0: TypedArray, rot_1: TypedArray, rot_2: TypedArray, rot_3: TypedArray) {
// texture format based vars
const { numComponents, isHalf } = this.format;
const quat = new Quat();

Expand Down Expand Up @@ -274,7 +278,6 @@ class Splat {
}

updateCenterData(x: TypedArray, y: TypedArray, z: TypedArray) {
// texture format based vars
const { numComponents, isHalf } = this.format;

const texture = this.centerTexture;
Expand Down
2 changes: 1 addition & 1 deletion src/viewer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1801,7 +1801,7 @@ class Viewer {
rt.resolve();
}

// perform mulitiframe update. returned flag indicates whether more frames
// perform multiframe update. returned flag indicates whether more frames
// are needed.
this.multiframeBusy = this.multiframe.update();
}
Expand Down

0 comments on commit 44cba90

Please sign in to comment.