Skip to content

Commit

Permalink
Add compute shader 2x2 box downsampling demo
Browse files Browse the repository at this point in the history
  • Loading branch information
httpdigest committed Oct 29, 2021
1 parent 60926aa commit 73b08c6
Show file tree
Hide file tree
Showing 4 changed files with 356 additions and 0 deletions.
81 changes: 81 additions & 0 deletions res/org/lwjgl/demo/opengl/shader/downsampling/downsample.cs.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright LWJGL. All rights reserved.
* License terms: https://www.lwjgl.org/license
*/
#version 430 core
#extension GL_KHR_shader_subgroup_shuffle : require

layout(location=0) uniform sampler2D baseImage;
layout(binding=0, rgba16f) uniform writeonly restrict image2D mips[3];

/*
* The assumption here is that each subgroup item maps to its corresponding local workgroup item
* according to gl_LocalInvocationID.x % gl_SubgroupSize == gl_SubgroupInvocationID.
* Our workgroups are 256 = 16 * 16 items in size.
*
* We will use z-order / morton-curve to layout the 256 threads in a workgroup
* across a 16x16 grid. That means, we still use (width/16, height/16, 1) workgroups
* to process the baseImage. We just redistribute the work items on a different
* 2D pattern.
*/
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

/*
 * Morton (z-order) decode helper: gathers the even-numbered bits (0, 2, 4, 6) of the
 * 8-bit curve index 'x' into a contiguous 4-bit value in [0..15].
 * Calling it with the index shifted right by one yields the odd-numbered bits instead.
 */
int unpack(int x) {
    // keep only bits 0, 2, 4 and 6
    int evens = x & 0x55;
    // move each kept bit next to its neighbour: two 2-bit groups remain
    int pairs = (evens ^ (evens >> 1)) & 0x33;
    // merge the two 2-bit groups into the low nibble
    return (pairs ^ (pairs >> 2)) & 0x0f;
}

/*
 * Entry point: every invocation produces one texel of mips[0] from a single filtered
 * fetch of baseImage, then the invocations of a subgroup combine their values with
 * subgroupShuffleXor() so that one invocation out of every 4 writes a texel of mips[1]
 * and one out of every 16 writes a texel of mips[2].
 */
void main(void) {
ivec2 ts = textureSize(baseImage, 0);

// the actual size of our work items is only half the baseImage size, because for the first mip level
// each work item already uses linear filtering with a sampler to gather a 2x2 texel average
ivec2 s = ts / ivec2(2);

// Compute the (x, y) coordinates of the current work item within its workgroup using z-order curve
// (x comes from the even bits of the local index, y from the odd bits)
ivec2 l = ivec2(unpack(int(gl_LocalInvocationID.x)),
unpack(int(gl_LocalInvocationID.x >> 1u)));

// Compute the global (x, y) coordinate of this work item
ivec2 i = ivec2(gl_WorkGroupID.xy) * ivec2(16) + l;

// compute mip 1 using linear filtering
// NOTE(review): invocations leaving here do not execute the shuffles below, so their
// neighbours would read values from invocations that already returned. Presumably those
// neighbours only feed mip 2/3 texels that lie outside the image - confirm for sizes
// that are not a multiple of the quad size.
if (i.x >= s.x || i.y >= s.y)
return;
// Compute a texture coordinate right at the corner between four texels
vec2 tc = (vec2(i * 2) + vec2(1.0)) / vec2(ts);
vec4 t = textureLod(baseImage, tc, 0.0);
imageStore(mips[0], i, t);

// compute mip 2 using subgroup quad sharing
/*
* The trick here is to assume a 1:1 correspondence between subgroup invocation ids
* and workgroup invocation ids (modulo the subgroup size).
* This way, together with our assumed Z-order swizzled layout, we know that
* for the subgroup [0, 1, 2, 3] forming a single 2x2 quad, e.g. the horizontal swap
* will come out correctly as [1, 0, 3, 2], etc.
*/
vec4 h = subgroupShuffleXor(t, 1);
vec4 v = subgroupShuffleXor(t, 2);
vec4 d = subgroupShuffleXor(t, 3);
// average of this invocation's value and its three quad neighbours
t = (t + h + v + d) * vec4(0.25);
// only the first invocation of each group of four writes the result
if ((gl_SubgroupInvocationID & 3) == 0)
imageStore(mips[1], i/ivec2(2), t);

// compute mip 3 using subgroup xor shuffles
/*
* The trick here is to exchange information between subgroup items with a stride
* of 4 items. In order to do this, we have subgroupShuffleXor().
*/
// NOTE(review): masks 4, 8 and 12 reach invocations up to 15 ids apart, which presumably
// requires gl_SubgroupSize >= 16 - confirm on the target hardware.
h = subgroupShuffleXor(t, 4);
v = subgroupShuffleXor(t, 8);
d = subgroupShuffleXor(t, 12);
// average of the four quad averages computed above
t = (t + h + v + d) * vec4(0.25);
// only the first invocation of each group of sixteen writes the result
if ((gl_SubgroupInvocationID & 15) == 0)
imageStore(mips[2], i/ivec2(4), t);
}
14 changes: 14 additions & 0 deletions res/org/lwjgl/demo/opengl/shader/downsampling/quad.fs.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright LWJGL. All rights reserved.
* License terms: https://www.lwjgl.org/license
*/
#version 430 core

uniform sampler2D tex;
uniform int level;
in vec2 texcoord;
out vec4 color;

/*
 * Samples 'tex' at the interpolated texture coordinate, explicitly selecting the
 * mip level given by the 'level' uniform.
 */
void main(void) {
    float lod = float(level);
    vec4 texel = textureLod(tex, texcoord, lod);
    color = texel;
}
15 changes: 15 additions & 0 deletions res/org/lwjgl/demo/opengl/shader/downsampling/quad.vs.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright LWJGL. All rights reserved.
* License terms: https://www.lwjgl.org/license
*/
#version 430 core

out vec2 texcoord;

/*
 * Generates a single triangle from gl_VertexID alone (no vertex attributes):
 * ids 0, 1, 2 map to (-1,-1), (3,-1), (-1,3), which encloses the whole [-1,1] clip square.
 */
void main(void) {
    float x = float((gl_VertexID & 1) << 2) - 1.0;
    float y = float((gl_VertexID & 2) << 1) - 1.0;
    gl_Position = vec4(x, y, 0.0, 1.0);
    // map clip-space [-1,1] to texture-space [0,1]
    texcoord = vec2(x, y) * 0.5 + vec2(0.5);
}
246 changes: 246 additions & 0 deletions src/org/lwjgl/demo/opengl/shader/DownsamplingDemo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
/*
* Copyright LWJGL. All rights reserved.
* License terms: https://www.lwjgl.org/license
*/
package org.lwjgl.demo.opengl.shader;

import static java.lang.Math.*;
import static org.lwjgl.glfw.Callbacks.glfwFreeCallbacks;
import static org.lwjgl.glfw.GLFW.*;
import static org.lwjgl.opengl.GL43C.*;
import static org.lwjgl.system.MemoryUtil.*;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;

import org.lwjgl.demo.opengl.util.DemoUtils;
import org.lwjgl.opengl.GL;
import org.lwjgl.opengl.GLCapabilities;
import org.lwjgl.opengl.GLUtil;
import org.lwjgl.system.Callback;
import org.lwjgl.system.MemoryStack;
import org.lwjgl.system.MemoryUtil;

/**
* Computes 3 mip levels of a texture using only a single compute shader dispatch
* and GL_KHR_shader_subgroup.
*
* @author Kai Burjack
*/
public class DownsamplingDemo {

// GLFW window handle, assigned in init()
private static long window;
// framebuffer size in pixels; also used as the size of mip level 0 of the texture
private static int width = 1024;
private static int height = 768;
// set by the framebuffer-size callback; loop() then recreates the texture
private static boolean resetTexture;

// vertex array object without any attribute setup, bound while drawing in present()
private static int nullVao;
// program running the downsampling compute shader
private static int computeProgram;
// program drawing the texture to the window
private static int quadProgram;
// the texture whose mip levels are computed and displayed
private static int texture;
// location of the "level" uniform in quadProgram
private static int levelUniform;
// mip level currently displayed; changed with the arrow up/down keys (0..3)
private static int level;

// debug message callback returned by GLUtil, freed in destroy() when non-null
private static Callback debugProc;

/**
 * Generates the vertex array object that {@link #present()} binds for its draw call
 * and stores its name in {@code nullVao}. No vertex attributes are configured on it.
 */
private static void createNullVao() {
    int vao = glGenVertexArrays();
    nullVao = vao;
}

/**
 * Builds the program used to draw the texture to the window and stores it in
 * {@code quadProgram}. Also looks up the {@code level} uniform location and points
 * the {@code tex} sampler at texture unit 0.
 *
 * @throws IOException propagated from {@code DemoUtils.createShader}
 * @throws AssertionError if the program fails to link
 */
private static void createQuadProgram() throws IOException {
    int quad = glCreateProgram();
    int vertexStage = DemoUtils.createShader(
            "org/lwjgl/demo/opengl/shader/downsampling/quad.vs.glsl", GL_VERTEX_SHADER);
    int fragmentStage = DemoUtils.createShader(
            "org/lwjgl/demo/opengl/shader/downsampling/quad.fs.glsl", GL_FRAGMENT_SHADER);
    int[] stages = { vertexStage, fragmentStage };

    for (int stage : stages) {
        glAttachShader(quad, stage);
    }
    glBindFragDataLocation(quad, 0, "color");
    glLinkProgram(quad);

    // the shader objects are no longer needed once the link has been attempted
    for (int stage : stages) {
        glDetachShader(quad, stage);
    }
    for (int stage : stages) {
        glDeleteShader(stage);
    }

    boolean linkSucceeded = glGetProgrami(quad, GL_LINK_STATUS) != 0;
    String infoLog = glGetProgramInfoLog(quad);
    if (!infoLog.trim().isEmpty()) {
        System.err.println(infoLog);
    }
    if (!linkSucceeded) {
        throw new AssertionError("Could not link program");
    }

    int samplerLocation = glGetUniformLocation(quad, "tex");
    levelUniform = glGetUniformLocation(quad, "level");

    // the sampler uniform is set once; it always reads from texture unit 0
    glUseProgram(quad);
    glUniform1i(samplerLocation, 0);
    glUseProgram(0);

    quadProgram = quad;
}

/**
 * Builds the program containing the downsampling compute shader and stores it in
 * {@code computeProgram}.
 *
 * @throws IOException propagated from {@code DemoUtils.createShader}
 * @throws AssertionError if the program fails to link
 */
private static void createComputeProgram() throws IOException {
    int compute = glCreateProgram();
    int computeStage = DemoUtils.createShader(
            "org/lwjgl/demo/opengl/shader/downsampling/downsample.cs.glsl", GL_COMPUTE_SHADER);

    glAttachShader(compute, computeStage);
    glLinkProgram(compute);
    // the shader object is no longer needed once the link has been attempted
    glDetachShader(compute, computeStage);
    glDeleteShader(computeStage);

    boolean linkSucceeded = glGetProgrami(compute, GL_LINK_STATUS) != 0;
    String infoLog = glGetProgramInfoLog(compute);
    if (!infoLog.trim().isEmpty()) {
        System.err.println(infoLog);
    }
    if (!linkSucceeded) {
        throw new AssertionError("Could not link program");
    }

    computeProgram = compute;
}

/**
 * Pattern generator for the test texture: full intensity on every eighth
 * coordinate, zero everywhere else.
 *
 * @param i texel coordinate along one axis
 * @return {@code (byte) 255} when {@code i} is a multiple of 8, otherwise {@code 0}
 */
private static byte v(int i) {
    if (i % 8 != 0) {
        return (byte) 0;
    }
    return (byte) 255;
}

/**
 * Creates the {@code width} x {@code height} RGBA16F texture with 4 mip levels, stores
 * its name in {@code texture} and uploads a line pattern into level 0.
 * Levels 1 to 3 are left uninitialized here; they are written by {@link #downsample()}.
 */
private static void createTextures() {
    texture = glGenTextures();
    glBindTexture(GL_TEXTURE_2D, texture);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST_MIPMAP_LINEAR);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
    // a 2D texture is addressed by the S and T coordinates, so both must be clamped
    // (WRAP_R is the third coordinate and has no effect on GL_TEXTURE_2D)
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
    glTexStorage2D(GL_TEXTURE_2D, 4, GL_RGBA16F, width, height);

    // 4 bytes (RGBA) per texel of level 0
    ByteBuffer pixels = MemoryUtil.memAlloc(width * height * 4);
    try {
        // fill the first level of the texture with some pattern
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                pixels.put(v(x)).put(v(y)).put((byte) 127).put((byte) 255);
            }
        }
        pixels.flip();
        glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width, height, GL_RGBA, GL_UNSIGNED_BYTE, pixels);
    } finally {
        // explicitly allocated native memory must be released even if the upload fails
        MemoryUtil.memFree(pixels);
    }
    glBindTexture(GL_TEXTURE_2D, 0);
}

/**
 * Runs the compute shader once to fill mip levels 1 to 3 of {@code texture} from
 * level 0, then makes the written levels visible to later sampling and resets the
 * bindings it changed.
 */
private static void downsample() {
    glUseProgram(computeProgram);

    // read mip level 0
    glBindTexture(GL_TEXTURE_2D, texture);
    // write mip levels 1-3
    glBindImageTexture(0, texture, 1, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
    glBindImageTexture(1, texture, 2, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
    glBindImageTexture(2, texture, 3, false, 0, GL_WRITE_ONLY, GL_RGBA16F);

    // each work item covers 2 texels of level 0 per axis, each workgroup 16 work items per axis
    int texelsPerWorkItem = 2;
    int workItemsPerGroup = 16;
    int texelsPerGroup = texelsPerWorkItem * workItemsPerGroup;
    // round up so that a partially covered border still gets a workgroup
    int numGroupsX = (width + texelsPerGroup - 1) / texelsPerGroup;
    int numGroupsY = (height + texelsPerGroup - 1) / texelsPerGroup;

    glDispatchCompute(numGroupsX, numGroupsY, 1);
    // The barrier bits describe how the written data is accessed AFTER the barrier:
    // present() reads the levels through a sampler (texture fetch), and the next call
    // of this method writes them again through image stores.
    glMemoryBarrier(GL_TEXTURE_FETCH_BARRIER_BIT | GL_SHADER_IMAGE_ACCESS_BARRIER_BIT);

    /* Reset bindings. */
    glBindTexture(GL_TEXTURE_2D, 0);
    glBindImageTexture(0, 0, 1, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
    glBindImageTexture(1, 0, 2, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
    glBindImageTexture(2, 0, 3, false, 0, GL_WRITE_ONLY, GL_RGBA16F);
    glUseProgram(0);
}

/**
 * Clears the framebuffer and draws the currently selected mip level of {@code texture}.
 * The viewport is shrunk by a factor of {@code 2^level} per axis.
 */
private static void present() {
    int shrink = 1 << level;

    glClear(GL_COLOR_BUFFER_BIT);
    glViewport(0, 0, width / shrink, height / shrink);

    glUseProgram(quadProgram);
    glUniform1i(levelUniform, level);
    glBindVertexArray(nullVao);
    glBindTexture(GL_TEXTURE_2D, texture);

    // three vertices: the vertex shader derives their positions from gl_VertexID
    glDrawArrays(GL_TRIANGLES, 0, 3);

    // restore the default bindings
    glBindTexture(GL_TEXTURE_2D, 0);
    glBindVertexArray(0);
    glUseProgram(0);
}

/**
 * Initializes GLFW, creates the (initially hidden) window with an OpenGL 4.3 core
 * context, installs the key and framebuffer-size callbacks, creates all GL objects
 * used by the demo and finally shows the window.
 *
 * @throws IOException propagated from the shader program creation
 * @throws IllegalStateException if GLFW cannot be initialized
 * @throws AssertionError if the window cannot be created or GL_KHR_shader_subgroup
 *         is not supported
 */
private static void init() throws IOException {
    if (!glfwInit()) {
        throw new IllegalStateException("Unable to initialize GLFW");
    }

    glfwDefaultWindowHints();
    int[][] windowHints = {
        { GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE },
        { GLFW_OPENGL_FORWARD_COMPAT, GLFW_TRUE },
        { GLFW_CONTEXT_VERSION_MAJOR, 4 },
        { GLFW_CONTEXT_VERSION_MINOR, 3 },
        { GLFW_VISIBLE, GLFW_FALSE },
        { GLFW_RESIZABLE, GLFW_TRUE },
    };
    for (int[] hint : windowHints) {
        glfwWindowHint(hint[0], hint[1]);
    }

    System.out.println("Press arrow up/down to increase/decrease the viewed mip level");
    window = glfwCreateWindow(width, height, "Downsampling Demo", NULL, NULL);
    if (window == NULL) {
        throw new AssertionError("Failed to create the GLFW window");
    }

    // all key bindings react to the key being released
    glfwSetKeyCallback(window, (wnd, key, scancode, action, mods) -> {
        if (action != GLFW_RELEASE) {
            return;
        }
        switch (key) {
        case GLFW_KEY_ESCAPE:
            glfwSetWindowShouldClose(window, true);
            break;
        case GLFW_KEY_UP:
            level = min(3, level + 1);
            break;
        case GLFW_KEY_DOWN:
            level = max(0, level - 1);
            break;
        default:
            break;
        }
    });

    // remember the new size and let loop() recreate the texture for it
    glfwSetFramebufferSizeCallback(window, (wnd, w, h) -> {
        boolean degenerate = w <= 0 || h <= 0;
        boolean unchanged = width == w && height == h;
        if (degenerate || unchanged) {
            return;
        }
        width = w;
        height = h;
        resetTexture = true;
    });

    // the framebuffer size is queried once up front and replaces the requested window size
    try (MemoryStack stack = MemoryStack.stackPush()) {
        IntBuffer size = stack.mallocInt(2);
        long widthAddress = memAddress(size);
        long heightAddress = widthAddress + Integer.BYTES;
        nglfwGetFramebufferSize(window, widthAddress, heightAddress);
        width = size.get(0);
        height = size.get(1);
    }

    glfwMakeContextCurrent(window);

    GLCapabilities caps = GL.createCapabilities();
    // Check required extensions
    boolean subgroupsSupported = caps.GL_KHR_shader_subgroup;
    if (!subgroupsSupported) {
        throw new AssertionError("GL_KHR_shader_subgroup is required but not supported");
    }

    debugProc = GLUtil.setupDebugMessageCallback();

    createTextures();
    createNullVao();
    createComputeProgram();
    createQuadProgram();

    glfwShowWindow(window);
}

/**
 * Main loop: until the window is asked to close, processes pending events, recreates
 * the texture when the framebuffer size changed, recomputes the mip levels and draws
 * the selected level.
 */
private static void loop() {
    while (!glfwWindowShouldClose(window)) {
        // callbacks run from here and may change level, width/height and resetTexture
        glfwPollEvents();

        boolean sizeChanged = resetTexture;
        if (sizeChanged) {
            resetTexture = false;
            // texture storage has a fixed size, so a new size needs a new texture
            glDeleteTextures(texture);
            createTextures();
        }

        downsample();
        present();
        glfwSwapBuffers(window);
    }
}

/**
 * Releases the debug callback (if one was installed), the window's GLFW callbacks,
 * the window itself and finally GLFW.
 */
private static void destroy() {
    if (debugProc != null) {
        debugProc.free();
    }
    // the callbacks must be freed while the window handle is still valid:
    // glfwFreeCallbacks accesses the window, so it has to run before the window is destroyed
    glfwFreeCallbacks(window);
    glfwDestroyWindow(window);
    glfwTerminate();
}

/**
 * Program entry point: initializes the demo, runs the render loop and tears
 * everything down afterwards.
 *
 * @param args command line arguments (unused)
 * @throws IOException propagated from {@link #init()}
 */
public static void main(String[] args) throws IOException {
    // init() stays outside the try block: if it fails there is no fully
    // initialized state that destroy() could safely tear down
    init();
    try {
        loop();
    } finally {
        // release the window and GLFW even when the render loop throws
        destroy();
    }
}

}

0 comments on commit 73b08c6

Please sign in to comment.