This commit is contained in:
2026-03-04 00:50:15 -08:00
parent 9126175569
commit 4211317c03
569 changed files with 122194 additions and 0 deletions

View File

@ -0,0 +1,36 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_CPU_COMPUTE
#include <Jolt/Compute/CPU/ComputeBufferCPU.h>
JPH_NAMESPACE_BEGIN
ComputeBufferCPU::ComputeBufferCPU(EType inType, uint64 inSize, uint inStride, const void *inData) :
ComputeBuffer(inType, inSize, inStride)
{
size_t buffer_size = size_t(mSize) * mStride;
mData = Allocate(buffer_size);
if (inData != nullptr)
memcpy(mData, inData, buffer_size);
}
ComputeBufferCPU::~ComputeBufferCPU()
{
Free(mData);
}
ComputeBufferResult ComputeBufferCPU::CreateReadBackBuffer() const
{
ComputeBufferResult result;
result.Set(const_cast<ComputeBufferCPU *>(this));
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,36 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeBuffer.h>
#ifdef JPH_USE_CPU_COMPUTE
JPH_NAMESPACE_BEGIN
/// Buffer that can be used with the CPU compute system
class JPH_EXPORT ComputeBufferCPU final : public ComputeBuffer
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor / destructor
ComputeBufferCPU(EType inType, uint64 inSize, uint inStride, const void *inData);
virtual ~ComputeBufferCPU() override;
ComputeBufferResult CreateReadBackBuffer() const override;
void * GetData() const { return mData; }
private:
virtual void * MapInternal(EMode inMode) override { return mData; }
virtual void UnmapInternal() override { /* Nothing to do */ }
void * mData;
};
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,101 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_CPU_COMPUTE
#include <Jolt/Compute/CPU/ComputeQueueCPU.h>
#include <Jolt/Compute/CPU/ComputeShaderCPU.h>
#include <Jolt/Compute/CPU/ComputeBufferCPU.h>
#include <Jolt/Compute/CPU/ShaderWrapper.h>
#include <Jolt/Compute/CPU/HLSLToCPP.h>
JPH_NAMESPACE_BEGIN
ComputeQueueCPU::~ComputeQueueCPU()
{
JPH_ASSERT(mShader == nullptr && mWrapper == nullptr);
}
void ComputeQueueCPU::SetShader(const ComputeShader *inShader)
{
JPH_ASSERT(mShader == nullptr && mWrapper == nullptr);
mShader = static_cast<const ComputeShaderCPU *>(inShader);
mWrapper = mShader->CreateWrapper();
}
void ComputeQueueCPU::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
const ComputeBufferCPU *buffer = static_cast<const ComputeBufferCPU *>(inBuffer);
mWrapper->Bind(inName, buffer->GetData(), buffer->GetSize() * buffer->GetStride());
mUsedBuffers.insert(buffer);
}
void ComputeQueueCPU::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
const ComputeBufferCPU *buffer = static_cast<const ComputeBufferCPU *>(inBuffer);
mWrapper->Bind(inName, buffer->GetData(), buffer->GetSize() * buffer->GetStride());
mUsedBuffers.insert(buffer);
}
void ComputeQueueCPU::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
const ComputeBufferCPU *buffer = static_cast<const ComputeBufferCPU *>(inBuffer);
mWrapper->Bind(inName, buffer->GetData(), buffer->GetSize() * buffer->GetStride());
mUsedBuffers.insert(buffer);
}
void ComputeQueueCPU::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
{
/* Nothing to read back */
}
void ComputeQueueCPU::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
{
uint nx = inThreadGroupsX * mShader->GetGroupSizeX();
uint ny = inThreadGroupsY * mShader->GetGroupSizeY();
uint nz = inThreadGroupsZ * mShader->GetGroupSizeZ();
for (uint z = 0; z < nz; ++z)
for (uint y = 0; y < ny; ++y)
for (uint x = 0; x < nx; ++x)
{
HLSLToCPP::uint3 tid { x, y, z };
mWrapper->Main(tid);
}
delete mWrapper;
mWrapper = nullptr;
mUsedBuffers.clear();
mShader = nullptr;
}
void ComputeQueueCPU::Execute()
{
/* Nothing to do */
}
void ComputeQueueCPU::Wait()
{
/* Nothing to do */
}
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,43 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeQueue.h>
#ifdef JPH_USE_CPU_COMPUTE
#include <Jolt/Compute/CPU/ComputeShaderCPU.h>
#include <Jolt/Core/UnorderedSet.h>
JPH_NAMESPACE_BEGIN
/// A command queue for the CPU compute system
class JPH_EXPORT ComputeQueueCPU final : public ComputeQueue
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Destructor
virtual ~ComputeQueueCPU() override;
// See: ComputeQueue
virtual void SetShader(const ComputeShader *inShader) override;
virtual void SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
virtual void ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
virtual void Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
virtual void Execute() override;
virtual void Wait() override;
private:
RefConst<ComputeShaderCPU> mShader = nullptr; ///< Current active shader
ShaderWrapper * mWrapper = nullptr; ///< The active shader wrapper
UnorderedSet<RefConst<ComputeBuffer>> mUsedBuffers; ///< Buffers that are in use by the current execution, these will be retained until execution is finished so that we don't free buffers that are in use
};
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,42 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeShader.h>
#ifdef JPH_USE_CPU_COMPUTE
JPH_NAMESPACE_BEGIN
class ShaderWrapper;
/// Compute shader handle for CPU compute
class JPH_EXPORT ComputeShaderCPU : public ComputeShader
{
public:
JPH_OVERRIDE_NEW_DELETE
using CreateShader = ShaderWrapper *(*)();
/// Constructor
ComputeShaderCPU(CreateShader inCreateShader, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ),
mCreateShader(inCreateShader)
{
}
/// Create an instance of the shader wrapper
ShaderWrapper * CreateWrapper() const
{
return mCreateShader();
}
private:
CreateShader mCreateShader;
};
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,56 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_CPU_COMPUTE
#include <Jolt/Compute/CPU/ComputeSystemCPU.h>
#include <Jolt/Compute/CPU/ComputeQueueCPU.h>
#include <Jolt/Compute/CPU/ComputeBufferCPU.h>
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemCPU)
{
JPH_ADD_BASE_CLASS(ComputeSystemCPU, ComputeSystem)
}
ComputeShaderResult ComputeSystemCPU::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
{
ComputeShaderResult result;
const ShaderRegistry::const_iterator it = mShaderRegistry.find(inName);
if (it == mShaderRegistry.end())
{
result.SetError("Compute shader not found");
return result;
}
result.Set(new ComputeShaderCPU(it->second, inGroupSizeX, inGroupSizeY, inGroupSizeZ));
return result;
}
ComputeBufferResult ComputeSystemCPU::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
{
ComputeBufferResult result;
result.Set(new ComputeBufferCPU(inType, inSize, inStride, inData));
return result;
}
ComputeQueueResult ComputeSystemCPU::CreateComputeQueue()
{
ComputeQueueResult result;
result.Set(new ComputeQueueCPU());
return result;
}
ComputeSystemResult CreateComputeSystemCPU()
{
ComputeSystemResult result;
result.Set(new ComputeSystemCPU());
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,52 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeSystem.h>
#ifdef JPH_USE_CPU_COMPUTE
#include <Jolt/Core/UnorderedMap.h>
#include <Jolt/Compute/CPU/ComputeShaderCPU.h>
JPH_NAMESPACE_BEGIN
/// Interface to run a workload on the CPU
/// This is intended mainly for debugging purposes and is not optimized for performance
class JPH_EXPORT ComputeSystemCPU : public ComputeSystem
{
public:
JPH_DECLARE_RTTI_VIRTUAL(JPH_EXPORT, ComputeSystemCPU)
// See: ComputeSystem
virtual ComputeShaderResult CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
virtual ComputeBufferResult CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
virtual ComputeQueueResult CreateComputeQueue() override;
using CreateShader = ComputeShaderCPU::CreateShader;
void RegisterShader(const char *inName, CreateShader inCreateShader)
{
mShaderRegistry[inName] = inCreateShader;
}
private:
using ShaderRegistry = UnorderedMap<string_view, CreateShader>;
ShaderRegistry mShaderRegistry;
};
// Internal helpers
#define JPH_SHADER_WRAPPER_FUNCTION_NAME(name) RegisterShader##name
#define JPH_SHADER_WRAPPER_FUNCTION(sys, name) void JPH_EXPORT JPH_SHADER_WRAPPER_FUNCTION_NAME(name)(ComputeSystemCPU *sys)
/// Macro to declare a shader register function
#define JPH_DECLARE_REGISTER_SHADER(name) namespace JPH { class ComputeSystemCPU; JPH_SHADER_WRAPPER_FUNCTION(, name); }
/// Macro to register a shader
#define JPH_REGISTER_SHADER(sys, name) JPH::JPH_SHADER_WRAPPER_FUNCTION_NAME(name)(sys)
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,525 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
JPH_NAMESPACE_BEGIN
/// Emulates HLSL vector types and operations in C++.
/// Note doesn't emulate things like barriers and group shared memory.
namespace HLSLToCPP {
using std::sqrt;
using std::min;
using std::max;
using std::round;
//////////////////////////////////////////////////////////////////////////////////////////
// float2
//////////////////////////////////////////////////////////////////////////////////////////
struct float2
{
// Constructors
inline float2() = default;
constexpr float2(float inX, float inY) : x(inX), y(inY) { }
explicit constexpr float2(float inS) : x(inS), y(inS) { }
// Operators
constexpr float2 & operator += (const float2 &inRHS) { x += inRHS.x; y += inRHS.y; return *this; }
constexpr float2 & operator -= (const float2 &inRHS) { x -= inRHS.x; y -= inRHS.y; return *this; }
constexpr float2 & operator *= (float inRHS) { x *= inRHS; y *= inRHS; return *this; }
constexpr float2 & operator /= (float inRHS) { x /= inRHS; y /= inRHS; return *this; }
constexpr float2 & operator *= (const float2 &inRHS) { x *= inRHS.x; y *= inRHS.y; return *this; }
constexpr float2 & operator /= (const float2 &inRHS) { x /= inRHS.x; y /= inRHS.y; return *this; }
// Equality
constexpr bool operator == (const float2 &inRHS) const { return x == inRHS.x && y == inRHS.y; }
constexpr bool operator != (const float2 &inRHS) const { return !(*this == inRHS); }
// Component access
const float & operator [] (uint inIndex) const { return (&x)[inIndex]; }
float & operator [] (uint inIndex) { return (&x)[inIndex]; }
// Swizzling (note return value is const to prevent assignment to swizzled results)
const float2 swizzle_xy() const { return float2(x, y); }
const float2 swizzle_yx() const { return float2(y, x); }
float x, y;
};
// Operators
constexpr float2 operator - (const float2 &inA) { return float2(-inA.x, -inA.y); }
constexpr float2 operator + (const float2 &inA, const float2 &inB) { return float2(inA.x + inB.x, inA.y + inB.y); }
constexpr float2 operator - (const float2 &inA, const float2 &inB) { return float2(inA.x - inB.x, inA.y - inB.y); }
constexpr float2 operator * (const float2 &inA, const float2 &inB) { return float2(inA.x * inB.x, inA.y * inB.y); }
constexpr float2 operator / (const float2 &inA, const float2 &inB) { return float2(inA.x / inB.x, inA.y / inB.y); }
constexpr float2 operator * (const float2 &inA, float inS) { return float2(inA.x * inS, inA.y * inS); }
constexpr float2 operator * (float inS, const float2 &inA) { return inA * inS; }
constexpr float2 operator / (const float2 &inA, float inS) { return float2(inA.x / inS, inA.y / inS); }
// Dot product
constexpr float dot(const float2 &inA, const float2 &inB) { return inA.x * inB.x + inA.y * inB.y; }
// Min value
constexpr float2 min(const float2 &inA, const float2 &inB) { return float2(min(inA.x, inB.x), min(inA.y, inB.y)); }
// Max value
constexpr float2 max(const float2 &inA, const float2 &inB) { return float2(max(inA.x, inB.x), max(inA.y, inB.y)); }
// Length
inline float length(const float2 &inV) { return sqrt(dot(inV, inV)); }
// Normalization
inline float2 normalize(const float2 &inV) { return inV / length(inV); }
// Rounding to int
inline float2 round(const float2 &inV) { return float2(round(inV.x), round(inV.y)); }
//////////////////////////////////////////////////////////////////////////////////////////
// float3
//////////////////////////////////////////////////////////////////////////////////////////
struct uint3;
struct float3
{
// Constructors
inline float3() = default;
constexpr float3(const float2 &inV, float inZ) : x(inV.x), y(inV.y), z(inZ) { }
constexpr float3(float inX, float inY, float inZ) : x(inX), y(inY), z(inZ) { }
explicit constexpr float3(float inS) : x(inS), y(inS), z(inS) { }
explicit constexpr float3(const uint3 &inV);
// Operators
constexpr float3 & operator += (const float3 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; return *this; }
constexpr float3 & operator -= (const float3 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; return *this; }
constexpr float3 & operator *= (float inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; return *this; }
constexpr float3 & operator /= (float inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; return *this; }
constexpr float3 & operator *= (const float3 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; return *this; }
constexpr float3 & operator /= (const float3 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; return *this; }
// Equality
constexpr bool operator == (const float3 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z; }
constexpr bool operator != (const float3 &inRHS) const { return !(*this == inRHS); }
// Component access
const float & operator [] (uint inIndex) const { return (&x)[inIndex]; }
float & operator [] (uint inIndex) { return (&x)[inIndex]; }
// Swizzling (note return value is const to prevent assignment to swizzled results)
const float2 swizzle_xy() const { return float2(x, y); }
const float2 swizzle_yx() const { return float2(y, x); }
const float3 swizzle_xyz() const { return float3(x, y, z); }
const float3 swizzle_xzy() const { return float3(x, z, y); }
const float3 swizzle_yxz() const { return float3(y, x, z); }
const float3 swizzle_yzx() const { return float3(y, z, x); }
const float3 swizzle_zxy() const { return float3(z, x, y); }
const float3 swizzle_zyx() const { return float3(z, y, x); }
float x, y, z;
};
// Operators
constexpr float3 operator - (const float3 &inA) { return float3(-inA.x, -inA.y, -inA.z); }
constexpr float3 operator + (const float3 &inA, const float3 &inB) { return float3(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z); }
constexpr float3 operator - (const float3 &inA, const float3 &inB) { return float3(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z); }
constexpr float3 operator * (const float3 &inA, const float3 &inB) { return float3(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z); }
constexpr float3 operator / (const float3 &inA, const float3 &inB) { return float3(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z); }
constexpr float3 operator * (const float3 &inA, float inS) { return float3(inA.x * inS, inA.y * inS, inA.z * inS); }
constexpr float3 operator * (float inS, const float3 &inA) { return inA * inS; }
constexpr float3 operator / (const float3 &inA, float inS) { return float3(inA.x / inS, inA.y / inS, inA.z / inS); }
// Dot product
constexpr float dot(const float3 &inA, const float3 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z; }
// Min value
constexpr float3 min(const float3 &inA, const float3 &inB) { return float3(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z)); }
// Max value
constexpr float3 max(const float3 &inA, const float3 &inB) { return float3(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z)); }
// Length
inline float length(const float3 &inV) { return sqrt(dot(inV, inV)); }
// Normalization
inline float3 normalize(const float3 &inV) { return inV / length(inV); }
// Rounding to int
inline float3 round(const float3 &inV) { return float3(round(inV.x), round(inV.y), round(inV.z)); }
// Cross product
constexpr float3 cross(const float3 &inA, const float3 &inB) { return float3(inA.y * inB.z - inA.z * inB.y, inA.z * inB.x - inA.x * inB.z, inA.x * inB.y - inA.y * inB.x); }
//////////////////////////////////////////////////////////////////////////////////////////
// float4
//////////////////////////////////////////////////////////////////////////////////////////
struct int4;
struct float4
{
// Constructors
inline float4() = default;
constexpr float4(const float3 &inV, float inW) : x(inV.x), y(inV.y), z(inV.z), w(inW) { }
constexpr float4(float inX, float inY, float inZ, float inW) : x(inX), y(inY), z(inZ), w(inW) { }
explicit constexpr float4(float inS) : x(inS), y(inS), z(inS), w(inS) { }
explicit constexpr float4(const int4 &inV);
// Operators
constexpr float4 & operator += (const float4 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; w += inRHS.w; return *this; }
constexpr float4 & operator -= (const float4 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; w -= inRHS.w; return *this; }
constexpr float4 & operator *= (float inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; w *= inRHS; return *this; }
constexpr float4 & operator /= (float inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; w /= inRHS; return *this; }
constexpr float4 & operator *= (const float4 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; w *= inRHS.w; return *this; }
constexpr float4 & operator /= (const float4 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; w /= inRHS.w; return *this; }
// Equality
constexpr bool operator == (const float4 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z && w == inRHS.w; }
constexpr bool operator != (const float4 &inRHS) const { return !(*this == inRHS); }
// Component access
const float & operator [] (uint inIndex) const { return (&x)[inIndex]; }
float & operator [] (uint inIndex) { return (&x)[inIndex]; }
// Swizzling (note return value is const to prevent assignment to swizzled results)
const float2 swizzle_xy() const { return float2(x, y); }
const float2 swizzle_yx() const { return float2(y, x); }
const float3 swizzle_xyz() const { return float3(x, y, z); }
const float3 swizzle_xzy() const { return float3(x, z, y); }
const float3 swizzle_yxz() const { return float3(y, x, z); }
const float3 swizzle_yzx() const { return float3(y, z, x); }
const float3 swizzle_zxy() const { return float3(z, x, y); }
const float3 swizzle_zyx() const { return float3(z, y, x); }
const float4 swizzle_xywz() const { return float4(x, y, w, z); }
const float4 swizzle_xwyz() const { return float4(x, w, y, z); }
const float4 swizzle_wxyz() const { return float4(w, x, y, z); }
float x, y, z, w;
};
// Operators
constexpr float4 operator - (const float4 &inA) { return float4(-inA.x, -inA.y, -inA.z, -inA.w); }
constexpr float4 operator + (const float4 &inA, const float4 &inB) { return float4(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z, inA.w + inB.w); }
constexpr float4 operator - (const float4 &inA, const float4 &inB) { return float4(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z, inA.w - inB.w); }
constexpr float4 operator * (const float4 &inA, const float4 &inB) { return float4(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z, inA.w * inB.w); }
constexpr float4 operator / (const float4 &inA, const float4 &inB) { return float4(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z, inA.w / inB.w); }
constexpr float4 operator * (const float4 &inA, float inS) { return float4(inA.x * inS, inA.y * inS, inA.z * inS, inA.w * inS); }
constexpr float4 operator * (float inS, const float4 &inA) { return inA * inS; }
constexpr float4 operator / (const float4 &inA, float inS) { return float4(inA.x / inS, inA.y / inS, inA.z / inS, inA.w / inS); }
// Dot product
constexpr float dot(const float4 &inA, const float4 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z + inA.w * inB.w; }
// Min value
constexpr float4 min(const float4 &inA, const float4 &inB) { return float4(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z), min(inA.w, inB.w)); }
// Max value
constexpr float4 max(const float4 &inA, const float4 &inB) { return float4(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z), max(inA.w, inB.w)); }
// Length
inline float length(const float4 &inV) { return sqrt(dot(inV, inV)); }
// Normalization
inline float4 normalize(const float4 &inV) { return inV / length(inV); }
// Rounding to int
inline float4 round(const float4 &inV) { return float4(round(inV.x), round(inV.y), round(inV.z), round(inV.w)); }
//////////////////////////////////////////////////////////////////////////////////////////
// uint3
//////////////////////////////////////////////////////////////////////////////////////////
struct uint3
{
inline uint3() = default;
constexpr uint3(uint32 inX, uint32 inY, uint32 inZ) : x(inX), y(inY), z(inZ) { }
explicit constexpr uint3(const float3 &inV) : x(uint32(inV.x)), y(uint32(inV.y)), z(uint32(inV.z)) { }
// Operators
constexpr uint3 & operator += (const uint3 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; return *this; }
constexpr uint3 & operator -= (const uint3 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; return *this; }
constexpr uint3 & operator *= (uint32 inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; return *this; }
constexpr uint3 & operator /= (uint32 inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; return *this; }
constexpr uint3 & operator *= (const uint3 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; return *this; }
constexpr uint3 & operator /= (const uint3 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; return *this; }
// Equality
constexpr bool operator == (const uint3 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z; }
constexpr bool operator != (const uint3 &inRHS) const { return !(*this == inRHS); }
// Component access
const uint32 & operator [] (uint inIndex) const { return (&x)[inIndex]; }
uint32 & operator [] (uint inIndex) { return (&x)[inIndex]; }
// Swizzling (note return value is const to prevent assignment to swizzled results)
const uint3 swizzle_xyz() const { return uint3(x, y, z); }
const uint3 swizzle_xzy() const { return uint3(x, z, y); }
const uint3 swizzle_yxz() const { return uint3(y, x, z); }
const uint3 swizzle_yzx() const { return uint3(y, z, x); }
const uint3 swizzle_zxy() const { return uint3(z, x, y); }
const uint3 swizzle_zyx() const { return uint3(z, y, x); }
uint32 x, y, z;
};
// Operators
constexpr uint3 operator + (const uint3 &inA, const uint3 &inB) { return uint3(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z); }
constexpr uint3 operator - (const uint3 &inA, const uint3 &inB) { return uint3(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z); }
constexpr uint3 operator * (const uint3 &inA, const uint3 &inB) { return uint3(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z); }
constexpr uint3 operator / (const uint3 &inA, const uint3 &inB) { return uint3(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z); }
constexpr uint3 operator * (const uint3 &inA, uint32 inS) { return uint3(inA.x * inS, inA.y * inS, inA.z * inS); }
constexpr uint3 operator * (uint32 inS, const uint3 &inA) { return inA * inS; }
constexpr uint3 operator / (const uint3 &inA, uint32 inS) { return uint3(inA.x / inS, inA.y / inS, inA.z / inS); }
// Dot product
constexpr uint32 dot(const uint3 &inA, const uint3 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z; }
// Min value
constexpr uint3 min(const uint3 &inA, const uint3 &inB) { return uint3(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z)); }
// Max value
constexpr uint3 max(const uint3 &inA, const uint3 &inB) { return uint3(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z)); }
//////////////////////////////////////////////////////////////////////////////////////////
// uint4
//////////////////////////////////////////////////////////////////////////////////////////
struct uint4
{
// Constructors
inline uint4() = default;
constexpr uint4(const uint3 &inV, uint32 inW) : x(inV.x), y(inV.y), z(inV.z), w(inW) { }
constexpr uint4(uint32 inX, uint32 inY, uint32 inZ, uint32 inW) : x(inX), y(inY), z(inZ), w(inW) { }
explicit constexpr uint4(uint32 inS) : x(inS), y(inS), z(inS), w(inS) { }
// Operators
constexpr uint4 & operator += (const uint4 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; w += inRHS.w; return *this; }
constexpr uint4 & operator -= (const uint4 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; w -= inRHS.w; return *this; }
constexpr uint4 & operator *= (uint32 inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; w *= inRHS; return *this; }
constexpr uint4 & operator /= (uint32 inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; w /= inRHS; return *this; }
constexpr uint4 & operator *= (const uint4 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; w *= inRHS.w; return *this; }
constexpr uint4 & operator /= (const uint4 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; w /= inRHS.w; return *this; }
// Equality
constexpr bool operator == (const uint4 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z && w == inRHS.w; }
constexpr bool operator != (const uint4 &inRHS) const { return !(*this == inRHS); }
// Component access
const uint32 & operator [] (uint inIndex) const { return (&x)[inIndex]; }
uint32 & operator [] (uint inIndex) { return (&x)[inIndex]; }
// Swizzling (note return value is const to prevent assignment to swizzled results)
const uint3 swizzle_xyz() const { return uint3(x, y, z); }
const uint3 swizzle_xzy() const { return uint3(x, z, y); }
const uint3 swizzle_yxz() const { return uint3(y, x, z); }
const uint3 swizzle_yzx() const { return uint3(y, z, x); }
const uint3 swizzle_zxy() const { return uint3(z, x, y); }
const uint3 swizzle_zyx() const { return uint3(z, y, x); }
const uint4 swizzle_xywz() const { return uint4(x, y, w, z); }
const uint4 swizzle_xwyz() const { return uint4(x, w, y, z); }
const uint4 swizzle_wxyz() const { return uint4(w, x, y, z); }
uint32 x, y, z, w;
};
// Operators
constexpr uint4 operator + (const uint4 &inA, const uint4 &inB) { return uint4(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z, inA.w + inB.w); }
constexpr uint4 operator - (const uint4 &inA, const uint4 &inB) { return uint4(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z, inA.w - inB.w); }
constexpr uint4 operator * (const uint4 &inA, const uint4 &inB) { return uint4(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z, inA.w * inB.w); }
constexpr uint4 operator / (const uint4 &inA, const uint4 &inB) { return uint4(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z, inA.w / inB.w); }
constexpr uint4 operator * (const uint4 &inA, uint32 inS) { return uint4(inA.x * inS, inA.y * inS, inA.z * inS, inA.w * inS); }
constexpr uint4 operator * (uint32 inS, const uint4 &inA) { return inA * inS; }
constexpr uint4 operator / (const uint4 &inA, uint32 inS) { return uint4(inA.x / inS, inA.y / inS, inA.z / inS, inA.w / inS); }
// Dot product
constexpr uint32 dot(const uint4 &inA, const uint4 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z + inA.w * inB.w; }
// Min value
constexpr uint4 min(const uint4 &inA, const uint4 &inB) { return uint4(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z), min(inA.w, inB.w)); }
// Max value
constexpr uint4 max(const uint4 &inA, const uint4 &inB) { return uint4(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z), max(inA.w, inB.w)); }
//////////////////////////////////////////////////////////////////////////////////////////
// int3
//////////////////////////////////////////////////////////////////////////////////////////
struct int3
{
inline int3() = default;
constexpr int3(int inX, int inY, int inZ) : x(inX), y(inY), z(inZ) { }
explicit constexpr int3(const float3 &inV) : x(int(inV.x)), y(int(inV.y)), z(int(inV.z)) { }
// Operators
constexpr int3 & operator += (const int3 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; return *this; }
constexpr int3 & operator -= (const int3 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; return *this; }
constexpr int3 & operator *= (int inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; return *this; }
constexpr int3 & operator /= (int inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; return *this; }
constexpr int3 & operator *= (const int3 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; return *this; }
constexpr int3 & operator /= (const int3 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; return *this; }
// Equality
constexpr bool operator == (const int3 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z; }
constexpr bool operator != (const int3 &inRHS) const { return !(*this == inRHS); }
// Component access
const int & operator [] (uint inIndex) const { return (&x)[inIndex]; }
int & operator [] (uint inIndex) { return (&x)[inIndex]; }
// Swizzling (note return value is const to prevent assignment to swizzled results)
const int3 swizzle_xyz() const { return int3(x, y, z); }
const int3 swizzle_xzy() const { return int3(x, z, y); }
const int3 swizzle_yxz() const { return int3(y, x, z); }
const int3 swizzle_yzx() const { return int3(y, z, x); }
const int3 swizzle_zxy() const { return int3(z, x, y); }
const int3 swizzle_zyx() const { return int3(z, y, x); }
int x, y, z;
};
// Operators
constexpr int3 operator - (const int3 &inA) { return int3(-inA.x, -inA.y, -inA.z); }
constexpr int3 operator + (const int3 &inA, const int3 &inB) { return int3(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z); }
constexpr int3 operator - (const int3 &inA, const int3 &inB) { return int3(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z); }
constexpr int3 operator * (const int3 &inA, const int3 &inB) { return int3(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z); }
constexpr int3 operator / (const int3 &inA, const int3 &inB) { return int3(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z); }
constexpr int3 operator * (const int3 &inA, int inS) { return int3(inA.x * inS, inA.y * inS, inA.z * inS); }
constexpr int3 operator * (int inS, const int3 &inA) { return inA * inS; }
constexpr int3 operator / (const int3 &inA, int inS) { return int3(inA.x / inS, inA.y / inS, inA.z / inS); }
// Dot product
constexpr int dot(const int3 &inA, const int3 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z; }
// Min value
constexpr int3 min(const int3 &inA, const int3 &inB) { return int3(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z)); }
// Max value
constexpr int3 max(const int3 &inA, const int3 &inB) { return int3(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z)); }
//////////////////////////////////////////////////////////////////////////////////////////
// int4
//////////////////////////////////////////////////////////////////////////////////////////
struct int4
{
// Constructors
inline int4() = default;
constexpr int4(const int3 &inV, int inW) : x(inV.x), y(inV.y), z(inV.z), w(inW) { }
constexpr int4(int inX, int inY, int inZ, int inW) : x(inX), y(inY), z(inZ), w(inW) { }
explicit constexpr int4(int inS) : x(inS), y(inS), z(inS), w(inS) { }
explicit constexpr int4(const float4 &inV) : x(int(inV.x)), y(int(inV.y)), z(int(inV.z)), w(int(inV.w)) { }
// Operators
constexpr int4 & operator += (const int4 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; w += inRHS.w; return *this; }
constexpr int4 & operator -= (const int4 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; w -= inRHS.w; return *this; }
constexpr int4 & operator *= (int inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; w *= inRHS; return *this; }
constexpr int4 & operator /= (int inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; w /= inRHS; return *this; }
constexpr int4 & operator *= (const int4 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; w *= inRHS.w; return *this; }
constexpr int4 & operator /= (const int4 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; w /= inRHS.w; return *this; }
// Equality
constexpr bool operator == (const int4 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z && w == inRHS.w; }
constexpr bool operator != (const int4 &inRHS) const { return !(*this == inRHS); }
// Component access
const int & operator [] (uint inIndex) const { return (&x)[inIndex]; }
int & operator [] (uint inIndex) { return (&x)[inIndex]; }
// Swizzling (note return value is const to prevent assignment to swizzled results)
const int3 swizzle_xyz() const { return int3(x, y, z); }
const int3 swizzle_xzy() const { return int3(x, z, y); }
const int3 swizzle_yxz() const { return int3(y, x, z); }
const int3 swizzle_yzx() const { return int3(y, z, x); }
const int3 swizzle_zxy() const { return int3(z, x, y); }
const int3 swizzle_zyx() const { return int3(z, y, x); }
const int4 swizzle_xywz() const { return int4(x, y, w, z); }
const int4 swizzle_xwyz() const { return int4(x, w, y, z); }
const int4 swizzle_wxyz() const { return int4(w, x, y, z); }
int x, y, z, w;
};
// Operators
constexpr int4 operator - (const int4 &inA) { return int4(-inA.x, -inA.y, -inA.z, -inA.w); }
constexpr int4 operator + (const int4 &inA, const int4 &inB) { return int4(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z, inA.w + inB.w); }
constexpr int4 operator - (const int4 &inA, const int4 &inB) { return int4(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z, inA.w - inB.w); }
constexpr int4 operator * (const int4 &inA, const int4 &inB) { return int4(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z, inA.w * inB.w); }
constexpr int4 operator / (const int4 &inA, const int4 &inB) { return int4(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z, inA.w / inB.w); }
constexpr int4 operator * (const int4 &inA, int inS) { return int4(inA.x * inS, inA.y * inS, inA.z * inS, inA.w * inS); }
constexpr int4 operator * (int inS, const int4 &inA) { return inA * inS; }
constexpr int4 operator / (const int4 &inA, int inS) { return int4(inA.x / inS, inA.y / inS, inA.z / inS, inA.w / inS); }
// Dot product
constexpr int dot(const int4 &inA, const int4 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z + inA.w * inB.w; }
// Min value
constexpr int4 min(const int4 &inA, const int4 &inB) { return int4(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z), min(inA.w, inB.w)); }
// Max value
constexpr int4 max(const int4 &inA, const int4 &inB) { return int4(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z), max(inA.w, inB.w)); }
//////////////////////////////////////////////////////////////////////////////////////////
// Mat44
//////////////////////////////////////////////////////////////////////////////////////////
struct Mat44
{
// Constructors
inline Mat44() = default;
constexpr Mat44(const float4 &inC0, const float4 &inC1, const float4 &inC2, const float4 &inC3) : c { inC0, inC1, inC2, inC3 } { }
// Columns
float4 & operator [] (uint inIndex) { return c[inIndex]; }
const float4 & operator [] (uint inIndex) const { return c[inIndex]; }
private:
float4 c[4];
};
//////////////////////////////////////////////////////////////////////////////////////////
// Other types
//////////////////////////////////////////////////////////////////////////////////////////
using Quat = float4;
using Plane = float4;
// Clamp value
template <class T>
constexpr T clamp(const T &inValue, const T &inMinValue, const T &inMaxValue)
{
return min(max(inValue, inMinValue), inMaxValue);
}
// Atomic add
template <class T>
T JPH_AtomicAdd(T &ioT, const T &inValue)
{
std::atomic<T> *value = reinterpret_cast<std::atomic<T> *>(&ioT);
return value->fetch_add(inValue) + inValue;
}
// Bitcast float4 to int4
inline int4 asint(const float4 &inV) { return int4(BitCast<int>(inV.x), BitCast<int>(inV.y), BitCast<int>(inV.z), BitCast<int>(inV.w)); }
// Functions that couldn't be declared earlier
constexpr float3::float3(const uint3 &inV) : x(float(inV.x)), y(float(inV.y)), z(float(inV.z)) { }
constexpr float4::float4(const int4 &inV) : x(float(inV.x)), y(float(inV.y)), z(float(inV.z)), w(float(inV.w)) { }
// Swizzle operators
#define xy swizzle_xy()
#define yx swizzle_yx()
#define xyz swizzle_xyz()
#define xzy swizzle_xzy()
#define yxz swizzle_yxz()
#define yzx swizzle_yzx()
#define zxy swizzle_zxy()
#define zyx swizzle_zyx()
#define xywz swizzle_xywz()
#define xwyz swizzle_xwyz()
#define wxyz swizzle_wxyz()
} // HLSLToCPP
JPH_NAMESPACE_END

View File

@ -0,0 +1,29 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_CPU_COMPUTE
JPH_NAMESPACE_BEGIN
namespace HLSLToCPP { struct uint3; }
/// Wraps a compute shader to allow calling it from C++
class ShaderWrapper
{
public:
/// Destructor
virtual ~ShaderWrapper() = default;
/// Bind buffer to shader
virtual void Bind(const char *inName, void *inData, uint64 inSize) = 0;
/// Execute a single shader thread
virtual void Main(const HLSLToCPP::uint3 &inThreadID) = 0;
};
JPH_NAMESPACE_END
#endif // JPH_USE_CPU_COMPUTE

View File

@ -0,0 +1,75 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Core/HashCombine.h>
#include <Jolt/Compute/CPU/ComputeSystemCPU.h>
#include <Jolt/Compute/CPU/ShaderWrapper.h>
#include <Jolt/Compute/CPU/HLSLToCPP.h>
/// @cond INTERNAL
JPH_NAMESPACE_BEGIN
JPH_MSVC_SUPPRESS_WARNING(5031) // #pragma warning(pop): likely mismatch, popping warning state pushed in different file
#define JPH_SHADER_OVERRIDE_MACROS
#define JPH_SHADER_GENERATE_WRAPPER
#define JPH_SHADER_CONSTANT(type, name, value) inline static constexpr type name = value;
#define JPH_SHADER_CONSTANTS_BEGIN(type, name) struct type { alignas(16) int dummy; } name; // Ensure that the first constant is 16 byte aligned
#define JPH_SHADER_CONSTANTS_MEMBER(type, name) type c##name;
#define JPH_SHADER_CONSTANTS_END(type)
#define JPH_SHADER_BUFFER(type) const type *
#define JPH_SHADER_RW_BUFFER(type) type *
#define JPH_SHADER_BIND_BEGIN(name)
#define JPH_SHADER_BIND_END(name)
#define JPH_SHADER_BIND_BUFFER(type, name) const type *name = nullptr;
#define JPH_SHADER_BIND_RW_BUFFER(type, name) type *name = nullptr;
#define JPH_SHADER_FUNCTION_BEGIN(return_type, name, group_size_x, group_size_y, group_size_z) \
virtual void Main(
#define JPH_SHADER_PARAM_THREAD_ID(name) const HLSLToCPP::uint3 &name
#define JPH_SHADER_FUNCTION_END ) override
#define JPH_SHADER_STRUCT_BEGIN(name) struct name {
#define JPH_SHADER_STRUCT_MEMBER(type, name) type m##name;
#define JPH_SHADER_STRUCT_END(name) };
#define JPH_TO_STRING(name) JPH_TO_STRING2(name)
#define JPH_TO_STRING2(name) #name
#define JPH_SHADER_CLASS_NAME(name) JPH_SHADER_CLASS_NAME2(name)
#define JPH_SHADER_CLASS_NAME2(name) name##ShaderWrapper
#define JPH_IN(type) const type &
#define JPH_OUT(type) type &
#define JPH_IN_OUT(type) type &
// Namespace to prevent 'using' from leaking out
namespace ShaderWrappers {
using namespace HLSLToCPP;
class JPH_SHADER_CLASS_NAME(JPH_SHADER_NAME) : public ShaderWrapper
{
public:
// Define types
using JPH_float = float;
using JPH_float3 = HLSLToCPP::float3;
using JPH_float4 = HLSLToCPP::float4;
using JPH_uint = uint;
using JPH_uint3 = HLSLToCPP::uint3;
using JPH_uint4 = HLSLToCPP::uint4;
using JPH_int = int;
using JPH_int3 = HLSLToCPP::int3;
using JPH_int4 = HLSLToCPP::int4;
using JPH_Quat = HLSLToCPP::Quat;
using JPH_Plane = HLSLToCPP::Plane;
using JPH_Mat44 = HLSLToCPP::Mat44;
// Now the shader code should be included followed by WrapShaderBindings.h
/// @endcond

View File

@ -0,0 +1,40 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
/// @cond INTERNAL
// First WrapShaderBegin.h should have been included, then the shader code
/// Bind a buffer to the shader
virtual void Bind(const char *inName, void *inData, uint64 inSize) override
{
// Don't redefine constants
#undef JPH_SHADER_CONSTANT
#define JPH_SHADER_CONSTANT(type, name, value)
// Don't redefine structs
#undef JPH_SHADER_STRUCT_BEGIN
#undef JPH_SHADER_STRUCT_MEMBER
#undef JPH_SHADER_STRUCT_END
#define JPH_SHADER_STRUCT_BEGIN(name)
#define JPH_SHADER_STRUCT_MEMBER(type, name)
#define JPH_SHADER_STRUCT_END(name)
// When a constant buffer is bound, copy the data into the members
#undef JPH_SHADER_CONSTANTS_BEGIN
#undef JPH_SHADER_CONSTANTS_MEMBER
#define JPH_SHADER_CONSTANTS_BEGIN(type, name) case HashString(#name): memcpy(&name + 1, inData, size_t(inSize)); break; // Very hacky way to get the address of the first constant and to copy the entire block of constants
#define JPH_SHADER_CONSTANTS_MEMBER(type, name)
// When a buffer is bound, set the pointer
#undef JPH_SHADER_BIND_BUFFER
#undef JPH_SHADER_BIND_RW_BUFFER
#define JPH_SHADER_BIND_BUFFER(type, name) case HashString(#name): name = (const type *)inData; break;
#define JPH_SHADER_BIND_RW_BUFFER(type, name) case HashString(#name): name = (type *)inData; break;
switch (HashString(inName))
{
// Now include the shader bindings followed by WrapShaderEnd.h
/// @endcond

View File

@ -0,0 +1,61 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
/// @cond INTERNAL
// WrapShaderBindings.h should have been included followed by the shader bindings
default:
JPH_ASSERT(false, "Buffer cannot be bound to this shader");
break;
}
}
/// Factory function to create a shader wrapper for this shader
static ShaderWrapper * sCreate()
{
return new JPH_SHADER_CLASS_NAME(JPH_SHADER_NAME)();
}
};
} // ShaderWrappers
/// @endcond
// Stop clang from complaining that the register function is missing a prototype
JPH_SHADER_WRAPPER_FUNCTION(, JPH_SHADER_NAME);
/// Register this wrapper
JPH_SHADER_WRAPPER_FUNCTION(inComputeSystem, JPH_SHADER_NAME)
{
inComputeSystem->RegisterShader(JPH_TO_STRING(JPH_SHADER_NAME), ShaderWrappers::JPH_SHADER_CLASS_NAME(JPH_SHADER_NAME)::sCreate);
}
#undef JPH_SHADER_OVERRIDE_MACROS
#undef JPH_SHADER_GENERATE_WRAPPER
#undef JPH_SHADER_CONSTANT
#undef JPH_SHADER_CONSTANTS_BEGIN
#undef JPH_SHADER_CONSTANTS_MEMBER
#undef JPH_SHADER_CONSTANTS_END
#undef JPH_SHADER_BUFFER
#undef JPH_SHADER_RW_BUFFER
#undef JPH_SHADER_BIND_BEGIN
#undef JPH_SHADER_BIND_END
#undef JPH_SHADER_BIND_BUFFER
#undef JPH_SHADER_BIND_RW_BUFFER
#undef JPH_SHADER_FUNCTION_BEGIN
#undef JPH_SHADER_PARAM_THREAD_ID
#undef JPH_SHADER_FUNCTION_END
#undef JPH_SHADER_STRUCT_BEGIN
#undef JPH_SHADER_STRUCT_MEMBER
#undef JPH_SHADER_STRUCT_END
#undef JPH_TO_STRING
#undef JPH_TO_STRING2
#undef JPH_SHADER_CLASS_NAME
#undef JPH_SHADER_CLASS_NAME2
#undef JPH_OUT
#undef JPH_IN_OUT
#undef JPH_SHADER_NAME
JPH_NAMESPACE_END

View File

@ -0,0 +1,69 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Core/Reference.h>
#include <Jolt/Core/NonCopyable.h>
#include <Jolt/Core/Result.h>
JPH_NAMESPACE_BEGIN
class ComputeBuffer;
using ComputeBufferResult = Result<Ref<ComputeBuffer>>;
/// Buffer that can be read from / written to by a compute shader
class JPH_EXPORT ComputeBuffer : public RefTarget<ComputeBuffer>, public NonCopyable
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Type of buffer
enum class EType
{
UploadBuffer, ///< Buffer that can be written on the CPU and then uploaded to the GPU.
ReadbackBuffer, ///< Buffer to be sent from the GPU to the CPU, used to read back data.
ConstantBuffer, ///< A smallish buffer that is used to pass constants to a shader.
Buffer, ///< Buffer that can be read from by a shader. Must be initialized with data at construction time and is read only thereafter.
RWBuffer, ///< Buffer that can be read from and written to by a shader.
};
/// Constructor / Destructor
ComputeBuffer(EType inType, uint64 inSize, uint inStride) : mType(inType), mSize(inSize), mStride(inStride) { }
virtual ~ComputeBuffer() { JPH_ASSERT(!mIsMapped); }
/// Properties
EType GetType() const { return mType; }
uint64 GetSize() const { return mSize; }
uint GetStride() const { return mStride; }
/// Mode in which the buffer is accessed
enum class EMode
{
Read, ///< Read only access to the buffer
Write, ///< Write only access to the buffer (this will discard all previous data in the buffer)
};
/// Map / unmap buffer (get pointer to data).
void * Map(EMode inMode) { JPH_ASSERT(!mIsMapped); JPH_IF_ENABLE_ASSERTS(mIsMapped = true;) return MapInternal(inMode); }
template <typename T> T * Map(EMode inMode) { JPH_ASSERT(!mIsMapped); JPH_IF_ENABLE_ASSERTS(mIsMapped = true;) JPH_ASSERT(sizeof(T) == mStride); return reinterpret_cast<T *>(MapInternal(inMode)); }
void Unmap() { JPH_ASSERT(mIsMapped); JPH_IF_ENABLE_ASSERTS(mIsMapped = false;) UnmapInternal(); }
/// Create a readback buffer of the same size and stride that can be used to read the data stored in this buffer on CPU.
/// Note that this could also be implemented as 'return this' in case the underlying implementation allows locking GPU data on CPU directly.
virtual ComputeBufferResult CreateReadBackBuffer() const = 0;
protected:
EType mType;
uint64 mSize;
uint mStride;
#ifdef JPH_ENABLE_ASSERTS
bool mIsMapped = false;
#endif // JPH_ENABLE_ASSERTS
virtual void * MapInternal(EMode inMode) = 0;
virtual void UnmapInternal() = 0;
};
JPH_NAMESPACE_END

View File

@ -0,0 +1,83 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Core/Reference.h>
#include <Jolt/Core/NonCopyable.h>
#include <Jolt/Core/Result.h>
JPH_NAMESPACE_BEGIN
class ComputeShader;
class ComputeBuffer;
/// A command queue for executing compute workloads on the GPU.
///
/// Note that only a single thread should be using a ComputeQueue at any time (although an implementation could be made that is thread safe).
class JPH_EXPORT ComputeQueue : public RefTarget<ComputeQueue>, public NonCopyable
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Destructor
virtual ~ComputeQueue() = default;
/// Activate a shader. Shader must be set first before buffers can be bound.
/// After every Dispatch call, the shader must be set again and all buffers must be bound again.
virtual void SetShader(const ComputeShader *inShader) = 0;
/// If a barrier should be placed before accessing the buffer
enum class EBarrier
{
Yes,
No
};
/// Bind a constant buffer to the shader. Note that the contents of the buffer cannot be modified until execution finishes.
/// A reference to the buffer is added to make sure it stays alive until execution finishes.
/// @param inName Name of the buffer as specified in the shader.
/// @param inBuffer The buffer to bind.
virtual void SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) = 0;
/// Bind a read only buffer to the shader. Note that the contents of the buffer cannot be modified on CPU until execution finishes (only relevant for buffers of type UploadBuffer).
/// A reference to the buffer is added to make sure it stays alive until execution finishes.
/// @param inName Name of the buffer as specified in the shader.
/// @param inBuffer The buffer to bind.
virtual void SetBuffer(const char *inName, const ComputeBuffer *inBuffer) = 0;
/// Bind a read/write buffer to the shader.
/// A reference to the buffer is added to make sure it stays alive until execution finishes.
/// @param inName Name of the buffer as specified in the shader.
/// @param inBuffer The buffer to bind.
/// @param inBarrier If set to Yes, a barrier will be placed before accessing the buffer to ensure all previous writes to the buffer are visible.
virtual void SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) = 0;
/// Dispatch a compute shader with the specified number of thread groups
virtual void Dispatch(uint inThreadGroupsX, uint inThreadGroupsY = 1, uint inThreadGroupsZ = 1) = 0;
/// Schedule buffer to be copied from GPU to CPU.
/// A reference to the buffers is added to make sure they stay alive until execution finishes.
virtual void ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) = 0;
/// Execute accumulated command list.
/// No more commands can be added until Wait is called.
virtual void Execute() = 0;
/// After executing, this waits until execution is done.
/// This also makes sure that any readback operations have completed and the data is available on CPU.
virtual void Wait() = 0;
/// Execute and wait for the command list to finish
/// @see Execute, Wait
void ExecuteAndWait()
{
Execute();
Wait();
}
};
using ComputeQueueResult = Result<Ref<ComputeQueue>>;
JPH_NAMESPACE_END

View File

@ -0,0 +1,41 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Core/Reference.h>
#include <Jolt/Core/NonCopyable.h>
#include <Jolt/Core/Result.h>
JPH_NAMESPACE_BEGIN
/// Compute shader handle
class JPH_EXPORT ComputeShader : public RefTarget<ComputeShader>, public NonCopyable
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor / destructor
ComputeShader(uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
mGroupSizeX(inGroupSizeX),
mGroupSizeY(inGroupSizeY),
mGroupSizeZ(inGroupSizeZ)
{
}
virtual ~ComputeShader() = default;
/// Get group sizes
uint32 GetGroupSizeX() const { return mGroupSizeX; }
uint32 GetGroupSizeY() const { return mGroupSizeY; }
uint32 GetGroupSizeZ() const { return mGroupSizeZ; }
private:
uint32 mGroupSizeX;
uint32 mGroupSizeY;
uint32 mGroupSizeZ;
};
using ComputeShaderResult = Result<Ref<ComputeShader>>;
JPH_NAMESPACE_END

View File

@ -0,0 +1,15 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#include <Jolt/Compute/ComputeSystem.h>
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_ABSTRACT_BASE(ComputeSystem)
{
}
JPH_NAMESPACE_END

View File

@ -0,0 +1,78 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeShader.h>
#include <Jolt/Compute/ComputeBuffer.h>
#include <Jolt/Compute/ComputeQueue.h>
#include <Jolt/Core/RTTI.h>
JPH_NAMESPACE_BEGIN
/// Interface to run a workload on the GPU
class JPH_EXPORT ComputeSystem : public RefTarget<ComputeSystem>, public NonCopyable
{
public:
JPH_DECLARE_RTTI_ABSTRACT_BASE(JPH_EXPORT, ComputeSystem)
/// Destructor
virtual ~ComputeSystem() = default;
/// Compile a compute shader
virtual ComputeShaderResult CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY = 1, uint32 inGroupSizeZ = 1) = 0;
/// Create a buffer for use with a compute shader
virtual ComputeBufferResult CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) = 0;
/// Create a queue for executing compute shaders
virtual ComputeQueueResult CreateComputeQueue() = 0;
/// Callback used when loading shaders
using ShaderLoader = std::function<bool(const char *inName, Array<uint8> &outData, String &outError)>;
ShaderLoader mShaderLoader = [](const char *, Array<uint8> &, String &outError) { JPH_ASSERT(false, "Override this function"); outError = "Not implemented"; return false; };
};
using ComputeSystemResult = Result<Ref<ComputeSystem>>;
#ifdef JPH_USE_VK
/// Factory function to create a compute system using Vulkan
extern JPH_EXPORT ComputeSystemResult CreateComputeSystemVK();
#endif
#ifdef JPH_USE_CPU_COMPUTE
/// Factory function to create a compute system that falls back to CPU.
/// This is intended mainly for debugging purposes and is not optimized for performance
extern JPH_EXPORT ComputeSystemResult CreateComputeSystemCPU();
#endif
#ifdef JPH_USE_DX12
/// Factory function to create a compute system using DirectX 12
extern JPH_EXPORT ComputeSystemResult CreateComputeSystemDX12();
/// Factory function to create the default compute system for this platform
inline ComputeSystemResult CreateComputeSystem() { return CreateComputeSystemDX12(); }
#elif defined(JPH_USE_MTL)
/// Factory function to create a compute system using Metal
extern JPH_EXPORT ComputeSystemResult CreateComputeSystemMTL();
/// Factory function to create the default compute system for this platform
inline ComputeSystemResult CreateComputeSystem() { return CreateComputeSystemMTL(); }
#elif defined(JPH_USE_VK)
/// Factory function to create the default compute system for this platform
inline ComputeSystemResult CreateComputeSystem() { return CreateComputeSystemVK(); }
#else
/// Fallback implementation when no compute system is available
inline ComputeSystemResult CreateComputeSystem() { ComputeSystemResult result; result.SetError("Not implemented"); return result; }
#endif
JPH_NAMESPACE_END

View File

@ -0,0 +1,167 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_DX12
#include <Jolt/Compute/DX12/ComputeBufferDX12.h>
#include <Jolt/Compute/DX12/ComputeSystemDX12.h>
JPH_NAMESPACE_BEGIN
ComputeBufferDX12::ComputeBufferDX12(ComputeSystemDX12 *inComputeSystem, EType inType, uint64 inSize, uint inStride) :
ComputeBuffer(inType, inSize, inStride),
mComputeSystem(inComputeSystem)
{
}
bool ComputeBufferDX12::Initialize(const void *inData)
{
uint64 buffer_size = mSize * mStride;
switch (mType)
{
case EType::UploadBuffer:
mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
mBufferGPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_NONE, buffer_size);
if (mBufferCPU == nullptr || mBufferGPU == nullptr)
return false;
break;
case EType::ConstantBuffer:
mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
if (mBufferCPU == nullptr)
return false;
break;
case EType::ReadbackBuffer:
JPH_ASSERT(inData == nullptr, "Can't upload data to a readback buffer");
mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_NONE, buffer_size);
if (mBufferCPU == nullptr)
return false;
break;
case EType::Buffer:
JPH_ASSERT(inData != nullptr);
mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
mBufferGPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_NONE, buffer_size);
if (mBufferCPU == nullptr || mBufferGPU == nullptr)
return false;
mNeedsSync = true;
break;
case EType::RWBuffer:
if (inData != nullptr)
{
mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
if (mBufferCPU == nullptr)
return false;
mNeedsSync = true;
}
mBufferGPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, buffer_size);
if (mBufferGPU == nullptr)
return false;
break;
}
// Copy data to upload buffer
if (inData != nullptr)
{
void *data = nullptr;
D3D12_RANGE range = { 0, 0 }; // We're not going to read
mBufferCPU->Map(0, &range, &data);
memcpy(data, inData, size_t(buffer_size));
mBufferCPU->Unmap(0, nullptr);
}
return true;
}
bool ComputeBufferDX12::Barrier(ID3D12GraphicsCommandList *inCommandList, D3D12_RESOURCE_STATES inTo) const
{
// Check if state changed
if (mCurrentState == inTo)
return false;
// Only buffers in GPU memory can change state
if (mType != ComputeBuffer::EType::Buffer && mType != ComputeBuffer::EType::RWBuffer)
return true;
D3D12_RESOURCE_BARRIER barrier;
barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
barrier.Transition.pResource = GetResourceGPU();
barrier.Transition.StateBefore = mCurrentState;
barrier.Transition.StateAfter = inTo;
barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
inCommandList->ResourceBarrier(1, &barrier);
mCurrentState = inTo;
return true;
}
void ComputeBufferDX12::RWBarrier(ID3D12GraphicsCommandList *inCommandList)
{
JPH_ASSERT(mCurrentState == D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
D3D12_RESOURCE_BARRIER barrier;
barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV;
barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
barrier.Transition.pResource = GetResourceGPU();
inCommandList->ResourceBarrier(1, &barrier);
}
bool ComputeBufferDX12::SyncCPUToGPU(ID3D12GraphicsCommandList *inCommandList) const
{
if (!mNeedsSync)
return false;
Barrier(inCommandList, D3D12_RESOURCE_STATE_COPY_DEST);
inCommandList->CopyResource(GetResourceGPU(), GetResourceCPU());
mNeedsSync = false;
return true;
}
void *ComputeBufferDX12::MapInternal(EMode inMode)
{
void *mapped_resource = nullptr;
switch (inMode)
{
case EMode::Read:
JPH_ASSERT(mType == EType::ReadbackBuffer);
if (HRFailed(mBufferCPU->Map(0, nullptr, &mapped_resource)))
return nullptr;
break;
case EMode::Write:
{
JPH_ASSERT(mType == EType::UploadBuffer || mType == EType::ConstantBuffer);
D3D12_RANGE range = { 0, 0 }; // We're not going to read
if (HRFailed(mBufferCPU->Map(0, &range, &mapped_resource)))
return nullptr;
mNeedsSync = true;
}
break;
}
return mapped_resource;
}
void ComputeBufferDX12::UnmapInternal()
{
mBufferCPU->Unmap(0, nullptr);
}
ComputeBufferResult ComputeBufferDX12::CreateReadBackBuffer() const
{
return mComputeSystem->CreateComputeBuffer(EType::ReadbackBuffer, mSize, mStride);
}
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,51 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeBuffer.h>
#ifdef JPH_USE_DX12
#include <Jolt/Compute/DX12/IncludeDX12.h>
JPH_NAMESPACE_BEGIN
class ComputeSystemDX12;
/// Buffer that can be read from / written to by a compute shader
class JPH_EXPORT ComputeBufferDX12 final : public ComputeBuffer
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor
ComputeBufferDX12(ComputeSystemDX12 *inComputeSystem, EType inType, uint64 inSize, uint inStride);
bool Initialize(const void *inData);
ID3D12Resource * GetResourceCPU() const { return mBufferCPU.Get(); }
ID3D12Resource * GetResourceGPU() const { return mBufferGPU.Get(); }
ComPtr<ID3D12Resource> ReleaseResourceCPU() const { return std::move(mBufferCPU); }
bool Barrier(ID3D12GraphicsCommandList *inCommandList, D3D12_RESOURCE_STATES inTo) const;
void RWBarrier(ID3D12GraphicsCommandList *inCommandList);
bool SyncCPUToGPU(ID3D12GraphicsCommandList *inCommandList) const;
ComputeBufferResult CreateReadBackBuffer() const override;
private:
virtual void * MapInternal(EMode inMode) override;
virtual void UnmapInternal() override;
ComputeSystemDX12 * mComputeSystem;
mutable ComPtr<ID3D12Resource> mBufferCPU;
ComPtr<ID3D12Resource> mBufferGPU;
mutable bool mNeedsSync = false; ///< If this buffer needs to be synced from CPU to GPU
mutable D3D12_RESOURCE_STATES mCurrentState = D3D12_RESOURCE_STATE_COPY_DEST; ///< State of the GPU buffer so we can do proper barriers
};
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,221 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_DX12
#include <Jolt/Compute/DX12/ComputeQueueDX12.h>
#include <Jolt/Compute/DX12/ComputeShaderDX12.h>
#include <Jolt/Compute/DX12/ComputeBufferDX12.h>
JPH_NAMESPACE_BEGIN
ComputeQueueDX12::~ComputeQueueDX12()
{
Wait();
if (mFenceEvent != INVALID_HANDLE_VALUE)
CloseHandle(mFenceEvent);
}
bool ComputeQueueDX12::Initialize(ID3D12Device *inDevice, D3D12_COMMAND_LIST_TYPE inType, ComputeQueueResult &outResult)
{
D3D12_COMMAND_QUEUE_DESC queue_desc = {};
queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
queue_desc.Type = inType;
queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_HIGH;
if (HRFailed(inDevice->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&mCommandQueue)), outResult))
return false;
if (HRFailed(inDevice->CreateCommandAllocator(inType, IID_PPV_ARGS(&mCommandAllocator)), outResult))
return false;
// Create the command list
if (HRFailed(inDevice->CreateCommandList(0, inType, mCommandAllocator.Get(), nullptr, IID_PPV_ARGS(&mCommandList)), outResult))
return false;
// Command lists are created in the recording state, but there is nothing to record yet. The main loop expects it to be closed, so close it now
if (HRFailed(mCommandList->Close(), outResult))
return false;
// Create synchronization object
if (HRFailed(inDevice->CreateFence(mFenceValue, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&mFence)), outResult))
return false;
// Increment fence value so we don't skip waiting the first time a command list is executed
mFenceValue++;
// Create an event handle to use for frame synchronization
mFenceEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
if (HRFailed(HRESULT_FROM_WIN32(GetLastError()), outResult))
return false;
return true;
}
ID3D12GraphicsCommandList *ComputeQueueDX12::Start()
{
JPH_ASSERT(!mIsExecuting);
if (!mIsStarted)
{
// Reset the allocator
if (HRFailed(mCommandAllocator->Reset()))
return nullptr;
// Reset the command list
if (HRFailed(mCommandList->Reset(mCommandAllocator.Get(), nullptr)))
return nullptr;
// Now we have started recording commands
mIsStarted = true;
}
return mCommandList.Get();
}
void ComputeQueueDX12::SetShader(const ComputeShader *inShader)
{
ID3D12GraphicsCommandList *command_list = Start();
mShader = static_cast<const ComputeShaderDX12 *>(inShader);
command_list->SetPipelineState(mShader->GetPipelineState());
command_list->SetComputeRootSignature(mShader->GetRootSignature());
}
void ComputeQueueDX12::SyncCPUToGPU(const ComputeBufferDX12 *inBuffer)
{
// Ensure that any CPU writes are visible to the GPU
if (inBuffer->SyncCPUToGPU(mCommandList.Get())
&& (inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer))
{
// After the first upload, the CPU buffer is no longer needed for Buffer and RWBuffer types
mDelayedFreedBuffers.emplace_back(inBuffer->ReleaseResourceCPU());
}
}
void ComputeQueueDX12::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
ID3D12GraphicsCommandList *command_list = Start();
const ComputeBufferDX12 *buffer = static_cast<const ComputeBufferDX12 *>(inBuffer);
command_list->SetComputeRootConstantBufferView(mShader->NameToIndex(inName), buffer->GetResourceCPU()->GetGPUVirtualAddress());
mUsedBuffers.insert(buffer);
}
void ComputeQueueDX12::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
ID3D12GraphicsCommandList *command_list = Start();
const ComputeBufferDX12 *buffer = static_cast<const ComputeBufferDX12 *>(inBuffer);
uint parameter_index = mShader->NameToIndex(inName);
SyncCPUToGPU(buffer);
buffer->Barrier(command_list, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE);
command_list->SetComputeRootShaderResourceView(parameter_index, buffer->GetResourceGPU()->GetGPUVirtualAddress());
mUsedBuffers.insert(buffer);
}
void ComputeQueueDX12::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
ID3D12GraphicsCommandList *command_list = Start();
ComputeBufferDX12 *buffer = static_cast<ComputeBufferDX12 *>(inBuffer);
uint parameter_index = mShader->NameToIndex(inName);
SyncCPUToGPU(buffer);
if (!buffer->Barrier(command_list, D3D12_RESOURCE_STATE_UNORDERED_ACCESS) && inBarrier == EBarrier::Yes)
buffer->RWBarrier(command_list);
command_list->SetComputeRootUnorderedAccessView(parameter_index, buffer->GetResourceGPU()->GetGPUVirtualAddress());
mUsedBuffers.insert(buffer);
}
void ComputeQueueDX12::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
{
if (inDst == nullptr || inSrc == nullptr)
return;
JPH_ASSERT(inDst->GetType() == ComputeBuffer::EType::ReadbackBuffer);
ID3D12GraphicsCommandList *command_list = Start();
ComputeBufferDX12 *dst = static_cast<ComputeBufferDX12 *>(inDst);
const ComputeBufferDX12 *src = static_cast<const ComputeBufferDX12 *>(inSrc);
dst->Barrier(command_list, D3D12_RESOURCE_STATE_COPY_DEST);
src->Barrier(command_list, D3D12_RESOURCE_STATE_COPY_SOURCE);
command_list->CopyResource(dst->GetResourceCPU(), src->GetResourceGPU());
mUsedBuffers.insert(src);
mUsedBuffers.insert(dst);
}
void ComputeQueueDX12::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
{
ID3D12GraphicsCommandList *command_list = Start();
command_list->Dispatch(inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
}
void ComputeQueueDX12::Execute()
{
JPH_ASSERT(mIsStarted);
JPH_ASSERT(!mIsExecuting);
// Close the command list
if (HRFailed(mCommandList->Close()))
return;
// Execute the command list
ID3D12CommandList *command_lists[] = { mCommandList.Get() };
mCommandQueue->ExecuteCommandLists((UINT)std::size(command_lists), command_lists);
// Schedule a Signal command in the queue
if (HRFailed(mCommandQueue->Signal(mFence.Get(), mFenceValue)))
return;
// Clear the current shader
mShader = nullptr;
// Mark that we're executing
mIsExecuting = true;
}
void ComputeQueueDX12::Wait()
{
// Check if we've been started
if (mIsExecuting)
{
if (mFence->GetCompletedValue() < mFenceValue)
{
// Wait until the fence has been processed
if (HRFailed(mFence->SetEventOnCompletion(mFenceValue, mFenceEvent)))
return;
WaitForSingleObjectEx(mFenceEvent, INFINITE, FALSE);
}
// Increment the fence value
mFenceValue++;
// Buffers can be freed now
mUsedBuffers.clear();
// Free buffers
mDelayedFreedBuffers.clear();
// Done executing
mIsExecuting = false;
mIsStarted = false;
}
}
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,61 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_DX12
#include <Jolt/Compute/ComputeQueue.h>
#include <Jolt/Compute/DX12/ComputeShaderDX12.h>
#include <Jolt/Core/UnorderedSet.h>
JPH_NAMESPACE_BEGIN
class ComputeBufferDX12;
/// A command queue for DirectX for executing compute workloads on the GPU.
class JPH_EXPORT ComputeQueueDX12 final : public ComputeQueue
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Destructor
virtual ~ComputeQueueDX12() override;
/// Initialize the queue
bool Initialize(ID3D12Device *inDevice, D3D12_COMMAND_LIST_TYPE inType, ComputeQueueResult &outResult);
/// Start the command list (requires waiting until the previous one is finished)
ID3D12GraphicsCommandList * Start();
// See: ComputeQueue
virtual void SetShader(const ComputeShader *inShader) override;
virtual void SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
virtual void ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
virtual void Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
virtual void Execute() override;
virtual void Wait() override;
private:
/// Copy the CPU buffer to the GPU buffer if needed
void SyncCPUToGPU(const ComputeBufferDX12 *inBuffer);
ComPtr<ID3D12CommandQueue> mCommandQueue; ///< The command queue that will hold command lists
ComPtr<ID3D12CommandAllocator> mCommandAllocator; ///< Allocator that holds the memory for the commands
ComPtr<ID3D12GraphicsCommandList> mCommandList; ///< The command list that will hold the render commands / state changes
HANDLE mFenceEvent = INVALID_HANDLE_VALUE; ///< Fence event, used to wait for rendering to complete
ComPtr<ID3D12Fence> mFence; ///< Fence object, used to signal the fence event
UINT64 mFenceValue = 0; ///< Current fence value, each time we need to wait we will signal the fence with this value, wait for it and then increase the value
RefConst<ComputeShaderDX12> mShader = nullptr; ///< Current active shader
bool mIsStarted = false; ///< If the command list has been started (reset) and is ready to record commands
bool mIsExecuting = false; ///< If a command list is currently executing on the queue
UnorderedSet<RefConst<ComputeBuffer>> mUsedBuffers; ///< Buffers that are in use by the current execution, these will be retained until execution is finished so that we don't free buffers that are in use
Array<ComPtr<ID3D12Resource>> mDelayedFreedBuffers; ///< Buffers freed during the execution
};
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,54 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_DX12
#include <Jolt/Compute/ComputeShader.h>
#include <Jolt/Compute/DX12/IncludeDX12.h>
#include <Jolt/Core/UnorderedMap.h>
JPH_NAMESPACE_BEGIN
/// Compute shader handle for DirectX
class JPH_EXPORT ComputeShaderDX12 : public ComputeShader
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor
ComputeShaderDX12(ComPtr<ID3DBlob> inShader, ComPtr<ID3D12RootSignature> inRootSignature, ComPtr<ID3D12PipelineState> inPipelineState, Array<String> &&inBindingNames, UnorderedMap<string_view, uint> &&inNameToIndex, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ),
mShader(inShader),
mRootSignature(inRootSignature),
mPipelineState(inPipelineState),
mBindingNames(std::move(inBindingNames)),
mNameToIndex(std::move(inNameToIndex))
{
}
/// Get index of shader parameter
uint NameToIndex(const char *inName) const
{
UnorderedMap<string_view, uint>::const_iterator it = mNameToIndex.find(inName);
JPH_ASSERT(it != mNameToIndex.end());
return it->second;
}
/// Getters
ID3D12PipelineState * GetPipelineState() const { return mPipelineState.Get(); }
ID3D12RootSignature * GetRootSignature() const { return mRootSignature.Get(); }
private:
ComPtr<ID3DBlob> mShader; ///< The compiled shader
ComPtr<ID3D12RootSignature> mRootSignature; ///< The root signature for this shader
ComPtr<ID3D12PipelineState> mPipelineState; ///< The pipeline state object for this shader
Array<String> mBindingNames; ///< A list of binding names, mNameToIndex points to these strings
UnorderedMap<string_view, uint> mNameToIndex; ///< Maps names to indices for the shader parameters, using a string_view so we can do find() without an allocation
};
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,443 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_DX12
#include <Jolt/Compute/DX12/ComputeSystemDX12.h>
#include <Jolt/Compute/DX12/ComputeQueueDX12.h>
#include <Jolt/Compute/DX12/ComputeShaderDX12.h>
#include <Jolt/Compute/DX12/ComputeBufferDX12.h>
#include <Jolt/Core/StringTools.h>
#include <Jolt/Core/UnorderedMap.h>
JPH_SUPPRESS_WARNINGS_STD_BEGIN
JPH_MSVC_SUPPRESS_WARNING(5204) // 'X': class has virtual functions, but its trivial destructor is not virtual; instances of objects derived from this class may not be destructed correctly
JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
#include <fstream>
#include <d3dcompiler.h>
#include <dxcapi.h>
#ifdef JPH_DEBUG
#include <d3d12sdklayers.h>
#endif
JPH_SUPPRESS_WARNINGS_STD_END
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemDX12)
{
JPH_ADD_BASE_CLASS(ComputeSystemDX12, ComputeSystem)
}
void ComputeSystemDX12::Initialize(ID3D12Device *inDevice, EDebug inDebug)
{
mDevice = inDevice;
mDebug = inDebug;
}
void ComputeSystemDX12::Shutdown()
{
mDevice.Reset();
}
ComPtr<ID3D12Resource> ComputeSystemDX12::CreateD3DResource(D3D12_HEAP_TYPE inHeapType, D3D12_RESOURCE_STATES inResourceState, D3D12_RESOURCE_FLAGS inFlags, uint64 inSize)
{
// Create a new resource
D3D12_RESOURCE_DESC desc;
desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
desc.Alignment = 0;
desc.Width = inSize;
desc.Height = 1;
desc.DepthOrArraySize = 1;
desc.MipLevels = 1;
desc.Format = DXGI_FORMAT_UNKNOWN;
desc.SampleDesc.Count = 1;
desc.SampleDesc.Quality = 0;
desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
desc.Flags = inFlags;
D3D12_HEAP_PROPERTIES heap_properties = {};
heap_properties.Type = inHeapType;
heap_properties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
heap_properties.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
heap_properties.CreationNodeMask = 1;
heap_properties.VisibleNodeMask = 1;
ComPtr<ID3D12Resource> resource;
if (HRFailed(mDevice->CreateCommittedResource(&heap_properties, D3D12_HEAP_FLAG_NONE, &desc, inResourceState, nullptr, IID_PPV_ARGS(&resource))))
return nullptr;
return resource;
}
ComputeShaderResult ComputeSystemDX12::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
{
ComputeShaderResult result;
// Read shader source file
Array<uint8> data;
String error;
String file_name = String(inName) + ".hlsl";
if (!mShaderLoader(file_name.c_str(), data, error))
{
result.SetError(error);
return result;
}
#ifndef JPH_USE_DXC // Use FXC, the old shader compiler?
UINT flags = D3DCOMPILE_ENABLE_STRICTNESS | D3DCOMPILE_WARNINGS_ARE_ERRORS | D3DCOMPILE_ALL_RESOURCES_BOUND;
#ifdef JPH_DEBUG
flags |= D3DCOMPILE_SKIP_OPTIMIZATION;
#else
flags |= D3DCOMPILE_OPTIMIZATION_LEVEL3;
#endif
if (mDebug == EDebug::DebugSymbols)
flags |= D3DCOMPILE_DEBUG;
const D3D_SHADER_MACRO defines[] =
{
{ nullptr, nullptr }
};
// Handles loading include files through the shader loader
struct IncludeHandler : public ID3DInclude
{
IncludeHandler(const ShaderLoader &inShaderLoader) : mShaderLoader(inShaderLoader) { }
virtual ~IncludeHandler() = default;
STDMETHOD (Open)(D3D_INCLUDE_TYPE, LPCSTR inFileName, LPCVOID, LPCVOID *outData, UINT *outNumBytes) override
{
// Read the header file
Array<uint8> file_data;
String error;
if (!mShaderLoader(inFileName, file_data, error))
return E_FAIL;
if (file_data.empty())
{
*outData = nullptr;
*outNumBytes = 0;
return S_OK;
}
// Copy to a new memory block
void *mem = CoTaskMemAlloc(file_data.size());
if (mem == nullptr)
return E_OUTOFMEMORY;
memcpy(mem, file_data.data(), file_data.size());
*outData = mem;
*outNumBytes = (UINT)file_data.size();
return S_OK;
}
STDMETHOD (Close)(LPCVOID inData) override
{
if (inData != nullptr)
CoTaskMemFree(const_cast<void *>(inData));
return S_OK;
}
private:
const ShaderLoader & mShaderLoader;
};
IncludeHandler include_handler(mShaderLoader);
// Compile source
ComPtr<ID3DBlob> shader_blob, error_blob;
if (FAILED(D3DCompile(&data[0],
(uint)data.size(),
file_name.c_str(),
defines,
&include_handler,
"main",
"cs_5_0",
flags,
0,
shader_blob.GetAddressOf(),
error_blob.GetAddressOf())))
{
if (error_blob)
result.SetError((const char *)error_blob->GetBufferPointer());
else
result.SetError("Shader compile error");
return result;
}
// Get shader description
ComPtr<ID3D12ShaderReflection> reflector;
if (FAILED(D3DReflect(shader_blob->GetBufferPointer(), shader_blob->GetBufferSize(), IID_PPV_ARGS(&reflector))))
{
result.SetError("Failed to reflect shader");
return result;
}
#else
ComPtr<IDxcUtils> utils;
DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(utils.GetAddressOf()));
// Custom include handler that forwards include loads to mShaderLoader
struct DxcIncludeHandler : public IDxcIncludeHandler
{
DxcIncludeHandler(IDxcUtils *inUtils, const ShaderLoader &inLoader) : mUtils(inUtils), mShaderLoader(inLoader) { }
virtual ~DxcIncludeHandler() = default;
STDMETHODIMP QueryInterface(REFIID riid, void **ppvObject) override
{
JPH_ASSERT(false);
return E_NOINTERFACE;
}
STDMETHODIMP_(ULONG) AddRef(void) override
{
// Allocated on the stack, we don't do ref counting
return 1;
}
STDMETHODIMP_(ULONG) Release(void) override
{
// Allocated on the stack, we don't do ref counting
return 1;
}
// IDxcIncludeHandler::LoadSource uses IDxcBlob**
STDMETHODIMP LoadSource(LPCWSTR inFilename, IDxcBlob **outIncludeSource) override
{
*outIncludeSource = nullptr;
// Convert to UTF-8
char file_name[MAX_PATH];
WideCharToMultiByte(CP_UTF8, 0, inFilename, -1, file_name, sizeof(file_name), nullptr, nullptr);
// Load the header
Array<uint8> file_data;
String error;
if (!mShaderLoader(file_name, file_data, error))
return E_FAIL;
// Create a blob from the loaded data
ComPtr<IDxcBlobEncoding> blob_encoder;
HRESULT hr = mUtils->CreateBlob(file_data.empty()? nullptr : file_data.data(), (uint)file_data.size(), CP_UTF8, blob_encoder.GetAddressOf());
if (FAILED(hr))
return hr;
// Return as IDxcBlob
*outIncludeSource = blob_encoder.Detach();
return S_OK;
}
IDxcUtils * mUtils;
const ShaderLoader & mShaderLoader;
};
DxcIncludeHandler include_handler(utils.Get(), mShaderLoader);
ComPtr<IDxcBlobEncoding> source;
if (HRFailed(utils->CreateBlob(data.data(), (uint)data.size(), CP_UTF8, source.GetAddressOf()), result))
return result;
ComPtr<IDxcCompiler3> compiler;
DxcCreateInstance(CLSID_DxcCompiler, IID_PPV_ARGS(compiler.GetAddressOf()));
Array<LPCWSTR> arguments;
arguments.push_back(L"-E");
arguments.push_back(L"main");
arguments.push_back(L"-T");
arguments.push_back(L"cs_6_0");
arguments.push_back(DXC_ARG_WARNINGS_ARE_ERRORS);
arguments.push_back(DXC_ARG_OPTIMIZATION_LEVEL3);
arguments.push_back(DXC_ARG_ALL_RESOURCES_BOUND);
if (mDebug == EDebug::DebugSymbols)
{
arguments.push_back(DXC_ARG_DEBUG);
arguments.push_back(L"-Qembed_debug");
}
// Provide file name so tools know what the original shader was called (the actual source comes from the blob)
wchar_t w_file_name[MAX_PATH];
MultiByteToWideChar(CP_UTF8, 0, file_name.c_str(), -1, w_file_name, MAX_PATH);
arguments.push_back(w_file_name);
// Compile the shader
DxcBuffer source_buffer;
source_buffer.Ptr = source->GetBufferPointer();
source_buffer.Size = source->GetBufferSize();
source_buffer.Encoding = 0;
ComPtr<IDxcResult> compile_result;
if (FAILED(compiler->Compile(&source_buffer, arguments.data(), (uint32)arguments.size(), &include_handler, IID_PPV_ARGS(compile_result.GetAddressOf()))))
{
result.SetError("Failed to compile shader");
return result;
}
// Check for compilation errors
ComPtr<IDxcBlobUtf8> errors;
compile_result->GetOutput(DXC_OUT_ERRORS, IID_PPV_ARGS(errors.GetAddressOf()), nullptr);
if (errors != nullptr && errors->GetStringLength() > 0)
{
result.SetError((const char *)errors->GetBufferPointer());
return result;
}
// Get the compiled shader code
ComPtr<ID3DBlob> shader_blob;
if (HRFailed(compile_result->GetOutput(DXC_OUT_OBJECT, IID_PPV_ARGS(shader_blob.GetAddressOf()), nullptr), result))
return result;
// Get reflection data
ComPtr<IDxcBlob> reflection_data;
if (HRFailed(compile_result->GetOutput(DXC_OUT_REFLECTION, IID_PPV_ARGS(reflection_data.GetAddressOf()), nullptr), result))
return result;
DxcBuffer reflection_buffer;
reflection_buffer.Ptr = reflection_data->GetBufferPointer();
reflection_buffer.Size = reflection_data->GetBufferSize();
reflection_buffer.Encoding = 0;
ComPtr<ID3D12ShaderReflection> reflector;
if (HRFailed(utils->CreateReflection(&reflection_buffer, IID_PPV_ARGS(reflector.GetAddressOf())), result))
return result;
#endif // JPH_USE_DXC
// Get the shader description
D3D12_SHADER_DESC shader_desc;
if (HRFailed(reflector->GetDesc(&shader_desc), result))
return result;
// Verify that the group sizes match the shader's thread group size
UINT thread_group_size_x, thread_group_size_y, thread_group_size_z;
if (HRFailed(reflector->GetThreadGroupSize(&thread_group_size_x, &thread_group_size_y, &thread_group_size_z), result))
return result;
JPH_ASSERT(inGroupSizeX == thread_group_size_x, "Group size X mismatch");
JPH_ASSERT(inGroupSizeY == thread_group_size_y, "Group size Y mismatch");
JPH_ASSERT(inGroupSizeZ == thread_group_size_z, "Group size Z mismatch");
// Convert parameters to root signature description
Array<String> binding_names;
binding_names.reserve(shader_desc.BoundResources);
UnorderedMap<string_view, uint> name_to_index;
Array<D3D12_ROOT_PARAMETER1> root_params;
for (UINT i = 0; i < shader_desc.BoundResources; ++i)
{
D3D12_SHADER_INPUT_BIND_DESC bind_desc;
reflector->GetResourceBindingDesc(i, &bind_desc);
D3D12_ROOT_PARAMETER1 param = {};
param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
switch (bind_desc.Type)
{
case D3D_SIT_CBUFFER:
param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_CBV;
break;
case D3D_SIT_STRUCTURED:
case D3D_SIT_BYTEADDRESS:
param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_SRV;
break;
case D3D_SIT_UAV_RWTYPED:
case D3D_SIT_UAV_RWSTRUCTURED:
case D3D_SIT_UAV_RWBYTEADDRESS:
case D3D_SIT_UAV_APPEND_STRUCTURED:
case D3D_SIT_UAV_CONSUME_STRUCTURED:
case D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER:
param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_UAV;
break;
case D3D_SIT_TBUFFER:
case D3D_SIT_TEXTURE:
case D3D_SIT_SAMPLER:
case D3D_SIT_RTACCELERATIONSTRUCTURE:
case D3D_SIT_UAV_FEEDBACKTEXTURE:
JPH_ASSERT(false, "Unsupported shader input type");
continue;
}
param.Descriptor.RegisterSpace = bind_desc.Space;
param.Descriptor.ShaderRegister = bind_desc.BindPoint;
param.Descriptor.Flags = D3D12_ROOT_DESCRIPTOR_FLAG_DATA_VOLATILE;
binding_names.push_back(bind_desc.Name); // Add all strings to a pool to keep them alive
name_to_index[string_view(binding_names.back())] = (uint)root_params.size();
root_params.push_back(param);
}
// Create the root signature
D3D12_VERSIONED_ROOT_SIGNATURE_DESC root_sig_desc = {};
root_sig_desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
root_sig_desc.Desc_1_1.NumParameters = (UINT)root_params.size();
root_sig_desc.Desc_1_1.pParameters = root_params.data();
root_sig_desc.Desc_1_1.NumStaticSamplers = 0;
root_sig_desc.Desc_1_1.pStaticSamplers = nullptr;
root_sig_desc.Desc_1_1.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
ComPtr<ID3DBlob> serialized_sig;
ComPtr<ID3DBlob> root_sig_error_blob;
if (FAILED(D3D12SerializeVersionedRootSignature(&root_sig_desc, &serialized_sig, &root_sig_error_blob)))
{
if (root_sig_error_blob)
{
error = StringFormat("Failed to create root signature: %s", (const char *)root_sig_error_blob->GetBufferPointer());
result.SetError(error);
}
else
result.SetError("Failed to create root signature");
return result;
}
ComPtr<ID3D12RootSignature> root_sig;
if (FAILED(mDevice->CreateRootSignature(0, serialized_sig->GetBufferPointer(), serialized_sig->GetBufferSize(), IID_PPV_ARGS(&root_sig))))
{
result.SetError("Failed to create root signature");
return result;
}
// Create a pipeline state object from the root signature and the shader
ComPtr<ID3D12PipelineState> pipeline_state;
D3D12_COMPUTE_PIPELINE_STATE_DESC compute_state_desc = {};
compute_state_desc.pRootSignature = root_sig.Get();
compute_state_desc.CS = { shader_blob->GetBufferPointer(), shader_blob->GetBufferSize() };
if (FAILED(mDevice->CreateComputePipelineState(&compute_state_desc, IID_PPV_ARGS(&pipeline_state))))
{
result.SetError("Failed to create compute pipeline state");
return result;
}
// Set name on DX12 objects for easier debugging
wchar_t w_name[1024];
size_t converted_chars = 0;
mbstowcs_s(&converted_chars, w_name, 1024, inName, _TRUNCATE);
pipeline_state->SetName(w_name);
result.Set(new ComputeShaderDX12(shader_blob, root_sig, pipeline_state, std::move(binding_names), std::move(name_to_index), inGroupSizeX, inGroupSizeY, inGroupSizeZ));
return result;
}
ComputeBufferResult ComputeSystemDX12::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
{
ComputeBufferResult result;
Ref<ComputeBufferDX12> buffer = new ComputeBufferDX12(this, inType, inSize, inStride);
if (!buffer->Initialize(inData))
{
result.SetError("Failed to create compute buffer");
return result;
}
result.Set(buffer.GetPtr());
return result;
}
ComputeQueueResult ComputeSystemDX12::CreateComputeQueue()
{
ComputeQueueResult result;
Ref<ComputeQueueDX12> queue = new ComputeQueueDX12();
if (!queue->Initialize(mDevice.Get(), D3D12_COMMAND_LIST_TYPE_COMPUTE, result))
return result;
result.Set(queue.GetPtr());
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,52 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Core/UnorderedMap.h>
#include <Jolt/Compute/ComputeSystem.h>
#ifdef JPH_USE_DX12
#include <Jolt/Compute/DX12/IncludeDX12.h>
JPH_NAMESPACE_BEGIN
/// Interface to run a workload on the GPU using DirectX 12.
/// Minimal implementation that can integrate with your own DirectX 12 setup.
class JPH_EXPORT ComputeSystemDX12 : public ComputeSystem
{
public:
JPH_DECLARE_RTTI_VIRTUAL(JPH_EXPORT, ComputeSystemDX12)
/// How we want to compile our shaders
enum class EDebug
{
NoDebugSymbols,
DebugSymbols
};
/// Initialize / shutdown
void Initialize(ID3D12Device *inDevice, EDebug inDebug);
void Shutdown();
// See: ComputeSystem
virtual ComputeShaderResult CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
virtual ComputeBufferResult CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
virtual ComputeQueueResult CreateComputeQueue() override;
/// Access to the DX12 device
ID3D12Device * GetDevice() const { return mDevice.Get(); }
// Function to create a ID3D12Resource on specified heap with specified state
ComPtr<ID3D12Resource> CreateD3DResource(D3D12_HEAP_TYPE inHeapType, D3D12_RESOURCE_STATES inResourceState, D3D12_RESOURCE_FLAGS inFlags, uint64 inSize);
private:
ComPtr<ID3D12Device> mDevice;
EDebug mDebug = EDebug::NoDebugSymbols;
};
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,154 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_DX12
#include <Jolt/Compute/DX12/ComputeSystemDX12Impl.h>
#ifdef JPH_DEBUG
#include <d3d12sdklayers.h>
#endif
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemDX12Impl)
{
JPH_ADD_BASE_CLASS(ComputeSystemDX12Impl, ComputeSystemDX12)
}
ComputeSystemDX12Impl::~ComputeSystemDX12Impl()
{
Shutdown();
mDXGIFactory.Reset();
#ifdef JPH_DEBUG
// Test for leaks
ComPtr<IDXGIDebug1> dxgi_debug;
if (SUCCEEDED(DXGIGetDebugInterface1(0, IID_PPV_ARGS(&dxgi_debug))))
dxgi_debug->ReportLiveObjects(DXGI_DEBUG_ALL, DXGI_DEBUG_RLO_ALL);
#endif
}
bool ComputeSystemDX12Impl::Initialize(ComputeSystemResult &outResult)
{
#if defined(JPH_DEBUG)
// Enable the D3D12 debug layer
ComPtr<ID3D12Debug> debug_controller;
if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug_controller))))
debug_controller->EnableDebugLayer();
#endif
// Create DXGI factory
if (HRFailed(CreateDXGIFactory1(IID_PPV_ARGS(&mDXGIFactory)), outResult))
return false;
// Find adapter
ComPtr<IDXGIAdapter1> adapter;
ComPtr<ID3D12Device> device;
HRESULT result = E_FAIL;
// First check if we have the Windows 1803 IDXGIFactory6 interface
ComPtr<IDXGIFactory6> factory6;
if (SUCCEEDED(mDXGIFactory->QueryInterface(IID_PPV_ARGS(&factory6))))
{
for (int search_software = 0; search_software < 2 && device == nullptr; ++search_software)
for (UINT index = 0; factory6->EnumAdapterByGpuPreference(index, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(&adapter)) != DXGI_ERROR_NOT_FOUND; ++index)
{
DXGI_ADAPTER_DESC1 desc;
adapter->GetDesc1(&desc);
// We don't want software renderers in the first pass
int is_software = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) != 0? 1 : 0;
if (search_software != is_software)
continue;
// Check to see whether the adapter supports Direct3D 12
#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
int prev_state = _CrtSetDbgFlag(0); // Temporarily disable leak detection as this call reports false positives
#endif
result = D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&device));
#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
_CrtSetDbgFlag(prev_state);
#endif
if (SUCCEEDED(result))
break;
}
}
else
{
// Fall back to the older method that may not get the fastest GPU
for (int search_software = 0; search_software < 2 && device == nullptr; ++search_software)
for (UINT index = 0; mDXGIFactory->EnumAdapters1(index, &adapter) != DXGI_ERROR_NOT_FOUND; ++index)
{
DXGI_ADAPTER_DESC1 desc;
adapter->GetDesc1(&desc);
// We don't want software renderers in the first pass
int is_software = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) != 0? 1 : 0;
if (search_software != is_software)
continue;
// Check to see whether the adapter supports Direct3D 12
#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
int prev_state = _CrtSetDbgFlag(0); // Temporarily disable leak detection as this call reports false positives
#endif
result = D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&device));
#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
_CrtSetDbgFlag(prev_state);
#endif
if (SUCCEEDED(result))
break;
}
}
// Check if we managed to obtain a device
if (HRFailed(result, outResult))
return false;
// Initialize the compute interface
ComputeSystemDX12::Initialize(device.Get(), EDebug::DebugSymbols);
#ifdef JPH_DEBUG
// Enable breaking on errors
ComPtr<ID3D12InfoQueue> info_queue;
if (SUCCEEDED(device.As(&info_queue)))
{
info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_CORRUPTION, TRUE);
info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_ERROR, TRUE);
info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_WARNING, TRUE);
// Disable an error that triggers on Windows 11 with a hybrid graphic system
// See: https://stackoverflow.com/questions/69805245/directx-12-application-is-crashing-in-windows-11
D3D12_MESSAGE_ID hide[] =
{
D3D12_MESSAGE_ID_RESOURCE_BARRIER_MISMATCHING_COMMAND_LIST_TYPE,
};
D3D12_INFO_QUEUE_FILTER filter = { };
filter.DenyList.NumIDs = static_cast<UINT>(std::size(hide));
filter.DenyList.pIDList = hide;
info_queue->AddStorageFilterEntries(&filter);
}
#endif // JPH_DEBUG
return true;
}
ComputeSystemResult CreateComputeSystemDX12()
{
ComputeSystemResult result;
Ref<ComputeSystemDX12Impl> compute = new ComputeSystemDX12Impl();
if (!compute->Initialize(result))
return result;
result.Set(compute.GetPtr());
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,33 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_DX12
#include <Jolt/Compute/DX12/ComputeSystemDX12.h>
JPH_NAMESPACE_BEGIN
/// Implementation of ComputeSystemDX12 that fully initializes DirectX 12
class JPH_EXPORT ComputeSystemDX12Impl : public ComputeSystemDX12
{
public:
JPH_DECLARE_RTTI_VIRTUAL(JPH_EXPORT, ComputeSystemDX12Impl)
/// Destructor
virtual ~ComputeSystemDX12Impl() override;
/// Initialize the compute system
bool Initialize(ComputeSystemResult &outResult);
IDXGIFactory4 * GetDXGIFactory() const { return mDXGIFactory.Get(); }
private:
ComPtr<IDXGIFactory4> mDXGIFactory;
};
JPH_NAMESPACE_END
#endif // JPH_USE_DX12

View File

@ -0,0 +1,49 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Core/IncludeWindows.h>
#include <Jolt/Core/StringTools.h>
JPH_SUPPRESS_WARNINGS_STD_BEGIN
JPH_MSVC_SUPPRESS_WARNING(4265) // 'X': class has virtual functions, but its non-trivial destructor is not virtual; instances of this class may not be destructed correctly
JPH_MSVC_SUPPRESS_WARNING(4625) // 'X': copy constructor was implicitly defined as deleted
JPH_MSVC_SUPPRESS_WARNING(4626) // 'X': assignment operator was implicitly defined as deleted
JPH_MSVC_SUPPRESS_WARNING(5204) // 'X': class has virtual functions, but its trivial destructor is not virtual; instances of objects derived from this class may not be destructed correctly
JPH_MSVC_SUPPRESS_WARNING(5220) // 'X': a non-static data member with a volatile qualified type no longer implies
JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
#include <d3d12.h>
#include <dxgi1_6.h>
#include <dxgidebug.h>
#include <wrl.h>
JPH_SUPPRESS_WARNINGS_STD_END
JPH_NAMESPACE_BEGIN
using Microsoft::WRL::ComPtr;
template <class Result>
inline bool HRFailed(HRESULT inHR, Result &outResult)
{
if (SUCCEEDED(inHR))
return false;
String error = StringFormat("Call failed with error code: %08X", inHR);
outResult.SetError(error);
JPH_ASSERT(false);
return true;
}
inline bool HRFailed(HRESULT inHR)
{
if (SUCCEEDED(inHR))
return false;
Trace("Call failed with error code: %08X", inHR);
JPH_ASSERT(false);
return true;
}
JPH_NAMESPACE_END

View File

@ -0,0 +1,39 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_MTL
#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
JPH_NAMESPACE_BEGIN
/// Buffer that can be read from / written to by a compute shader
class JPH_EXPORT ComputeBufferMTL final : public ComputeBuffer
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor
ComputeBufferMTL(ComputeSystemMTL *inComputeSystem, EType inType, uint64 inSize, uint inStride);
virtual ~ComputeBufferMTL() override;
bool Initialize(const void *inData);
virtual ComputeBufferResult CreateReadBackBuffer() const override;
id<MTLBuffer> GetBuffer() const { return mBuffer; }
private:
virtual void * MapInternal(EMode inMode) override;
virtual void UnmapInternal() override;
ComputeSystemMTL * mComputeSystem;
id<MTLBuffer> mBuffer;
};
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,52 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_MTL
#include <Jolt/Compute/MTL/ComputeBufferMTL.h>
JPH_NAMESPACE_BEGIN
ComputeBufferMTL::ComputeBufferMTL(ComputeSystemMTL *inComputeSystem, EType inType, uint64 inSize, uint inStride) :
ComputeBuffer(inType, inSize, inStride),
mComputeSystem(inComputeSystem)
{
}
bool ComputeBufferMTL::Initialize(const void *inData)
{
NSUInteger size = NSUInteger(mSize) * mStride;
if (inData != nullptr)
mBuffer = [mComputeSystem->GetDevice() newBufferWithBytes: inData length: size options: MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared | MTLResourceHazardTrackingModeTracked];
else
mBuffer = [mComputeSystem->GetDevice() newBufferWithLength: size options: MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared | MTLResourceHazardTrackingModeTracked];
return mBuffer != nil;
}
ComputeBufferMTL::~ComputeBufferMTL()
{
[mBuffer release];
}
void *ComputeBufferMTL::MapInternal(EMode inMode)
{
return mBuffer.contents;
}
void ComputeBufferMTL::UnmapInternal()
{
}
ComputeBufferResult ComputeBufferMTL::CreateReadBackBuffer() const
{
ComputeBufferResult result;
result.Set(const_cast<ComputeBufferMTL *>(this));
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,49 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_MTL
#include <MetalKit/MetalKit.h>
#include <Jolt/Compute/ComputeQueue.h>
JPH_NAMESPACE_BEGIN
class ComputeShaderMTL;
/// A command queue for Metal for executing compute workloads on the GPU.
class JPH_EXPORT ComputeQueueMTL final : public ComputeQueue
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor / destructor
ComputeQueueMTL(id<MTLDevice> inDevice);
virtual ~ComputeQueueMTL() override;
// See: ComputeQueue
virtual void SetShader(const ComputeShader *inShader) override;
virtual void SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
virtual void ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
virtual void Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
virtual void Execute() override;
virtual void Wait() override;
private:
void BeginCommandBuffer();
id<MTLCommandQueue> mCommandQueue;
id<MTLCommandBuffer> mCommandBuffer;
id<MTLComputeCommandEncoder> mComputeEncoder;
RefConst<ComputeShaderMTL> mShader;
bool mIsExecuting = false;
};
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,123 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_MTL
#include <Jolt/Compute/MTL/ComputeQueueMTL.h>
#include <Jolt/Compute/MTL/ComputeShaderMTL.h>
#include <Jolt/Compute/MTL/ComputeBufferMTL.h>
#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
JPH_NAMESPACE_BEGIN
ComputeQueueMTL::~ComputeQueueMTL()
{
Wait();
[mCommandQueue release];
}
ComputeQueueMTL::ComputeQueueMTL(id<MTLDevice> inDevice)
{
// Create the command queue
mCommandQueue = [inDevice newCommandQueue];
}
void ComputeQueueMTL::BeginCommandBuffer()
{
if (mCommandBuffer == nil)
{
// Start a new command buffer
mCommandBuffer = [mCommandQueue commandBuffer];
mComputeEncoder = [mCommandBuffer computeCommandEncoder];
}
}
void ComputeQueueMTL::SetShader(const ComputeShader *inShader)
{
BeginCommandBuffer();
mShader = static_cast<const ComputeShaderMTL *>(inShader);
[mComputeEncoder setComputePipelineState: mShader->GetPipelineState()];
}
void ComputeQueueMTL::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
BeginCommandBuffer();
const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
[mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
}
void ComputeQueueMTL::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
BeginCommandBuffer();
const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
[mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
}
void ComputeQueueMTL::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
BeginCommandBuffer();
const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
[mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
}
void ComputeQueueMTL::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
{
JPH_ASSERT(inDst == inSrc); // Since ComputeBuffer::CreateReadBackBuffer returns the same buffer, we don't need to copy
}
void ComputeQueueMTL::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
{
BeginCommandBuffer();
MTLSize thread_groups = MTLSizeMake(inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
MTLSize group_size = MTLSizeMake(mShader->GetGroupSizeX(), mShader->GetGroupSizeY(), mShader->GetGroupSizeZ());
[mComputeEncoder dispatchThreadgroups: thread_groups threadsPerThreadgroup: group_size];
}
void ComputeQueueMTL::Execute()
{
// End command buffer
if (mCommandBuffer == nil)
return;
[mComputeEncoder endEncoding];
[mCommandBuffer commit];
mShader = nullptr;
mIsExecuting = true;
}
void ComputeQueueMTL::Wait()
{
if (!mIsExecuting)
return;
[mCommandBuffer waitUntilCompleted];
mComputeEncoder = nil;
mCommandBuffer = nil;
mIsExecuting = false;
}
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,39 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_MTL
#include <MetalKit/MetalKit.h>
#include <Jolt/Compute/ComputeShader.h>
#include <Jolt/Core/UnorderedMap.h>
JPH_NAMESPACE_BEGIN
/// Compute shader handle for Metal
class JPH_EXPORT ComputeShaderMTL : public ComputeShader
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor
ComputeShaderMTL(id<MTLComputePipelineState> inPipelineState, MTLComputePipelineReflection *inReflection, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ);
virtual ~ComputeShaderMTL() override { [mPipelineState release]; }
/// Access to the function
id<MTLComputePipelineState> GetPipelineState() const { return mPipelineState; }
/// Get index of buffer name
uint NameToBindingIndex(const char *inName) const;
private:
id<MTLComputePipelineState> mPipelineState;
UnorderedMap<String, uint> mNameToBindingIndex;
};
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,34 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_MTL
#include <Jolt/Compute/MTL/ComputeShaderMTL.h>
JPH_NAMESPACE_BEGIN
ComputeShaderMTL::ComputeShaderMTL(id<MTLComputePipelineState> inPipelineState, MTLComputePipelineReflection *inReflection, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ),
mPipelineState(inPipelineState)
{
for (id<MTLBinding> binding in inReflection.bindings)
{
const char *name = [binding.name UTF8String];
uint index = uint(binding.index);
mNameToBindingIndex[name] = index;
}
}
uint ComputeShaderMTL::NameToBindingIndex(const char *inName) const
{
UnorderedMap<String, uint>::const_iterator it = mNameToBindingIndex.find(inName);
JPH_ASSERT(it != mNameToBindingIndex.end());
return it->second;
}
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,40 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeSystem.h>
#ifdef JPH_USE_MTL
#include <MetalKit/MetalKit.h>
JPH_NAMESPACE_BEGIN
/// Interface to run a workload on the GPU
class JPH_EXPORT ComputeSystemMTL : public ComputeSystem
{
public:
JPH_DECLARE_RTTI_VIRTUAL(JPH_EXPORT, ComputeSystemMTL)
// Initialize / shutdown the compute system
bool Initialize(id<MTLDevice> inDevice);
void Shutdown();
// See: ComputeSystem
virtual ComputeShaderResult CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
virtual ComputeBufferResult CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
virtual ComputeQueueResult CreateComputeQueue() override;
/// Get the metal device
id<MTLDevice> GetDevice() const { return mDevice; }
private:
id<MTLDevice> mDevice;
id<MTLLibrary> mShaderLibrary;
};
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,110 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_MTL
#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
#include <Jolt/Compute/MTL/ComputeBufferMTL.h>
#include <Jolt/Compute/MTL/ComputeShaderMTL.h>
#include <Jolt/Compute/MTL/ComputeQueueMTL.h>
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemMTL)
{
JPH_ADD_BASE_CLASS(ComputeSystemMTL, ComputeSystem)
}
bool ComputeSystemMTL::Initialize(id<MTLDevice> inDevice)
{
mDevice = [inDevice retain];
return true;
}
void ComputeSystemMTL::Shutdown()
{
[mShaderLibrary release];
[mDevice release];
}
ComputeShaderResult ComputeSystemMTL::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
{
ComputeShaderResult result;
if (mShaderLibrary == nil)
{
// Load the shader library containing all shaders
Array<uint8> *data = new Array<uint8>();
String error;
if (!mShaderLoader("Jolt.metallib", *data, error))
{
result.SetError(error);
delete data;
return result;
}
// Convert to dispatch data
dispatch_data_t data_dispatch = dispatch_data_create(data->data(), data->size(), nullptr, ^{ delete data; });
// Create the library
NSError *ns_error = nullptr;
mShaderLibrary = [mDevice newLibraryWithData: data_dispatch error: &ns_error];
if (ns_error != nil)
{
result.SetError("Failed to laod shader library");
return result;
}
}
// Get the shader function
id<MTLFunction> function = [mShaderLibrary newFunctionWithName: [NSString stringWithCString: inName encoding: NSUTF8StringEncoding]];
if (function == nil)
{
result.SetError("Failed to instantiate compute shader");
return result;
}
// Create the pipeline
NSError *error = nil;
MTLComputePipelineReflection *reflection = nil;
id<MTLComputePipelineState> pipeline_state = [mDevice newComputePipelineStateWithFunction: function options: MTLPipelineOptionBindingInfo | MTLPipelineOptionBufferTypeInfo reflection: &reflection error: &error];
if (error != nil || pipeline_state == nil)
{
result.SetError("Failed to create compute pipeline");
[function release];
return result;
}
result.Set(new ComputeShaderMTL(pipeline_state, reflection, inGroupSizeX, inGroupSizeY, inGroupSizeZ));
return result;
}
ComputeBufferResult ComputeSystemMTL::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
{
ComputeBufferResult result;
Ref<ComputeBufferMTL> buffer = new ComputeBufferMTL(this, inType, inSize, inStride);
if (!buffer->Initialize(inData))
{
result.SetError("Failed to create compute buffer");
return result;
}
result.Set(buffer.GetPtr());
return result;
}
ComputeQueueResult ComputeSystemMTL::CreateComputeQueue()
{
ComputeQueueResult result;
result.Set(new ComputeQueueMTL(mDevice));
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,28 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_MTL
#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
JPH_NAMESPACE_BEGIN
/// Interface to run a workload on the GPU that fully initializes Metal.
class JPH_EXPORT ComputeSystemMTLImpl : public ComputeSystemMTL
{
public:
JPH_DECLARE_RTTI_VIRTUAL(JPH_EXPORT, ComputeSystemMTLImpl)
/// Destructor
virtual ~ComputeSystemMTLImpl() override;
/// Initialize / shutdown the compute system
bool Initialize();
};
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,49 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_MTL
#include <Jolt/Compute/MTL/ComputeSystemMTLImpl.h>
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemMTLImpl)
{
JPH_ADD_BASE_CLASS(ComputeSystemMTLImpl, ComputeSystemMTL)
}
ComputeSystemMTLImpl::~ComputeSystemMTLImpl()
{
Shutdown();
[GetDevice() release];
}
bool ComputeSystemMTLImpl::Initialize()
{
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
return ComputeSystemMTL::Initialize(device);
}
ComputeSystemResult CreateComputeSystemMTL()
{
ComputeSystemResult result;
Ref<ComputeSystemMTLImpl> compute = new ComputeSystemMTLImpl;
if (!compute->Initialize())
{
result.SetError("Failed to initialize compute system");
return result;
}
result.Set(compute.GetPtr());
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_MTL

View File

@ -0,0 +1,42 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2024 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/VK/IncludeVK.h>
#include <Jolt/Core/Reference.h>
#include <Jolt/Core/NonCopyable.h>
JPH_NAMESPACE_BEGIN
/// Simple wrapper class to manage a Vulkan memory block
class MemoryVK : public RefTarget<MemoryVK>, public NonCopyable
{
public:
~MemoryVK()
{
// We should have unmapped and freed the block before destruction
JPH_ASSERT(mMappedCount == 0);
JPH_ASSERT(mMemory == VK_NULL_HANDLE);
}
VkDeviceMemory mMemory = VK_NULL_HANDLE; ///< The Vulkan memory handle
VkDeviceSize mSize = 0; ///< Size of the memory block
VkDeviceSize mBufferSize = 0; ///< Size of each of the buffers that this memory block has been divided into
VkMemoryPropertyFlags mProperties = 0; ///< Vulkan memory properties used to allocate this block
int mMappedCount = 0; ///< How often buffers using this memory block were mapped
void * mMappedPtr = nullptr; ///< The CPU address of the memory block when mapped
};
/// Simple wrapper class to manage a Vulkan buffer
class BufferVK
{
public:
Ref<MemoryVK> mMemory; ///< The memory block that contains the buffer (note that filling this in is optional if you do your own buffer allocation)
VkBuffer mBuffer = VK_NULL_HANDLE; ///< The Vulkan buffer handle
VkDeviceSize mOffset = 0; ///< Offset in the memory block where the buffer starts
VkDeviceSize mSize = 0; ///< Real size of the buffer
};
JPH_NAMESPACE_END

View File

@ -0,0 +1,140 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeBufferVK.h>
#include <Jolt/Compute/VK/ComputeSystemVK.h>
JPH_NAMESPACE_BEGIN
ComputeBufferVK::ComputeBufferVK(ComputeSystemVK *inComputeSystem, EType inType, uint64 inSize, uint inStride) :
ComputeBuffer(inType, inSize, inStride),
mComputeSystem(inComputeSystem)
{
}
bool ComputeBufferVK::Initialize(const void *inData)
{
VkDeviceSize buffer_size = VkDeviceSize(mSize * mStride);
switch (mType)
{
case EType::Buffer:
JPH_ASSERT(inData != nullptr);
[[fallthrough]];
case EType::UploadBuffer:
case EType::RWBuffer:
if (!mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_TRANSFER_SRC_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT, mBufferCPU))
return false;
if (!mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, mBufferGPU))
return false;
if (inData != nullptr)
{
void *data = mComputeSystem->MapBuffer(mBufferCPU);
memcpy(data, inData, size_t(buffer_size));
mComputeSystem->UnmapBuffer(mBufferCPU);
mNeedsSync = true;
}
break;
case EType::ConstantBuffer:
if (!mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT, mBufferCPU))
return false;
if (inData != nullptr)
{
void* data = mComputeSystem->MapBuffer(mBufferCPU);
memcpy(data, inData, size_t(buffer_size));
mComputeSystem->UnmapBuffer(mBufferCPU);
}
break;
case EType::ReadbackBuffer:
JPH_ASSERT(inData == nullptr, "Can't upload data to a readback buffer");
if (!mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT, mBufferCPU))
return false;
break;
}
return true;
}
ComputeBufferVK::~ComputeBufferVK()
{
mComputeSystem->FreeBuffer(mBufferGPU);
mComputeSystem->FreeBuffer(mBufferCPU);
}
void ComputeBufferVK::Barrier(VkCommandBuffer inCommandBuffer, VkPipelineStageFlags inToStage, VkAccessFlagBits inToFlags, bool inForce) const
{
if (mAccessStage == inToStage && mAccessFlagBits == inToFlags && !inForce)
return;
VkBufferMemoryBarrier b = {};
b.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
b.srcAccessMask = mAccessFlagBits;
b.dstAccessMask = inToFlags;
b.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
b.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
b.buffer = mBufferGPU.mBuffer != VK_NULL_HANDLE? mBufferGPU.mBuffer : mBufferCPU.mBuffer;
b.offset = 0;
b.size = VK_WHOLE_SIZE;
vkCmdPipelineBarrier(inCommandBuffer, mAccessStage, inToStage, 0, 0, nullptr, 1, &b, 0, nullptr);
mAccessStage = inToStage;
mAccessFlagBits = inToFlags;
}
bool ComputeBufferVK::SyncCPUToGPU(VkCommandBuffer inCommandBuffer) const
{
if (!mNeedsSync)
return false;
// Barrier before write
Barrier(inCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_WRITE_BIT, false);
// Copy from CPU to GPU
VkBufferCopy copy = {};
copy.srcOffset = 0;
copy.dstOffset = 0;
copy.size = GetSize() * GetStride();
vkCmdCopyBuffer(inCommandBuffer, mBufferCPU.mBuffer, mBufferGPU.mBuffer, 1, &copy);
mNeedsSync = false;
return true;
}
void *ComputeBufferVK::MapInternal(EMode inMode)
{
switch (inMode)
{
case EMode::Read:
JPH_ASSERT(mType == EType::ReadbackBuffer);
break;
case EMode::Write:
JPH_ASSERT(mType == EType::UploadBuffer || mType == EType::ConstantBuffer);
mNeedsSync = true;
break;
}
return mComputeSystem->MapBuffer(mBufferCPU);
}
void ComputeBufferVK::UnmapInternal()
{
mComputeSystem->UnmapBuffer(mBufferCPU);
}
ComputeBufferResult ComputeBufferVK::CreateReadBackBuffer() const
{
return mComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ReadbackBuffer, mSize, mStride);
}
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,52 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeBuffer.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/BufferVK.h>
JPH_NAMESPACE_BEGIN
class ComputeSystemVK;
/// Buffer that can be read from / written to by a compute shader
class JPH_EXPORT ComputeBufferVK final : public ComputeBuffer
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor
ComputeBufferVK(ComputeSystemVK *inComputeSystem, EType inType, uint64 inSize, uint inStride);
virtual ~ComputeBufferVK() override;
bool Initialize(const void *inData);
virtual ComputeBufferResult CreateReadBackBuffer() const override;
VkBuffer GetBufferCPU() const { return mBufferCPU.mBuffer; }
VkBuffer GetBufferGPU() const { return mBufferGPU.mBuffer; }
BufferVK ReleaseBufferCPU() const { BufferVK tmp = mBufferCPU; mBufferCPU = BufferVK(); return tmp; }
void Barrier(VkCommandBuffer inCommandBuffer, VkPipelineStageFlags inToStage, VkAccessFlagBits inToFlags, bool inForce) const;
bool SyncCPUToGPU(VkCommandBuffer inCommandBuffer) const;
private:
virtual void * MapInternal(EMode inMode) override;
virtual void UnmapInternal() override;
ComputeSystemVK * mComputeSystem;
mutable BufferVK mBufferCPU;
BufferVK mBufferGPU;
mutable bool mNeedsSync = false; ///< If this buffer needs to be synced from CPU to GPU
mutable VkAccessFlagBits mAccessFlagBits = VK_ACCESS_SHADER_READ_BIT; ///< Access flags of the last usage, used for barriers
mutable VkPipelineStageFlags mAccessStage = VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; ///< Pipeline stage of the last usage, used for barriers
};
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,304 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeQueueVK.h>
#include <Jolt/Compute/VK/ComputeBufferVK.h>
#include <Jolt/Compute/VK/ComputeSystemVK.h>
JPH_NAMESPACE_BEGIN
ComputeQueueVK::~ComputeQueueVK()
{
Wait();
VkDevice device = mComputeSystem->GetDevice();
if (mCommandBuffer != VK_NULL_HANDLE)
vkFreeCommandBuffers(device, mCommandPool, 1, &mCommandBuffer);
if (mCommandPool != VK_NULL_HANDLE)
vkDestroyCommandPool(device, mCommandPool, nullptr);
if (mDescriptorPool != VK_NULL_HANDLE)
vkDestroyDescriptorPool(device, mDescriptorPool, nullptr);
if (mFence != VK_NULL_HANDLE)
vkDestroyFence(device, mFence, nullptr);
}
bool ComputeQueueVK::Initialize(uint32 inComputeQueueIndex, ComputeQueueResult &outResult)
{
// Get the queue
VkDevice device = mComputeSystem->GetDevice();
vkGetDeviceQueue(device, inComputeQueueIndex, 0, &mQueue);
// Create a command pool
VkCommandPoolCreateInfo pool_info = {};
pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
pool_info.queueFamilyIndex = inComputeQueueIndex;
if (VKFailed(vkCreateCommandPool(device, &pool_info, nullptr, &mCommandPool), outResult))
return false;
// Create descriptor pool
VkDescriptorPoolSize descriptor_pool_sizes[] = {
{ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1024 },
{ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 16 * 1024 },
};
VkDescriptorPoolCreateInfo descriptor_info = {};
descriptor_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descriptor_info.poolSizeCount = (uint32)std::size(descriptor_pool_sizes);
descriptor_info.pPoolSizes = descriptor_pool_sizes;
descriptor_info.maxSets = 256;
if (VKFailed(vkCreateDescriptorPool(device, &descriptor_info, nullptr, &mDescriptorPool), outResult))
return false;
// Create a command buffer
VkCommandBufferAllocateInfo alloc_info = {};
alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
alloc_info.commandPool = mCommandPool;
alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
alloc_info.commandBufferCount = 1;
if (VKFailed(vkAllocateCommandBuffers(device, &alloc_info, &mCommandBuffer), outResult))
return false;
// Create a fence
VkFenceCreateInfo fence_info = {};
fence_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
if (VKFailed(vkCreateFence(device, &fence_info, nullptr, &mFence), outResult))
return false;
return true;
}
bool ComputeQueueVK::BeginCommandBuffer()
{
if (!mCommandBufferRecording)
{
VkCommandBufferBeginInfo begin_info = {};
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
if (VKFailed(vkBeginCommandBuffer(mCommandBuffer, &begin_info)))
return false;
mCommandBufferRecording = true;
}
return true;
}
void ComputeQueueVK::SetShader(const ComputeShader *inShader)
{
mShader = static_cast<const ComputeShaderVK *>(inShader);
mBufferInfos = mShader->GetBufferInfos();
}
void ComputeQueueVK::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
if (!BeginCommandBuffer())
return;
const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_ACCESS_UNIFORM_READ_BIT, false);
uint index = mShader->NameToBufferInfoIndex(inName);
JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
mBufferInfos[index].buffer = buffer->GetBufferCPU();
mUsedBuffers.insert(buffer);
}
void ComputeQueueVK::SyncCPUToGPU(const ComputeBufferVK *inBuffer)
{
// Ensure that any CPU writes are visible to the GPU
if (inBuffer->SyncCPUToGPU(mCommandBuffer)
&& (inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer))
{
// After the first upload, the CPU buffer is no longer needed for Buffer and RWBuffer types
mDelayedFreedBuffers.push_back(inBuffer->ReleaseBufferCPU());
}
}
void ComputeQueueVK::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
if (!BeginCommandBuffer())
return;
const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
SyncCPUToGPU(buffer);
buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_ACCESS_SHADER_READ_BIT, false);
uint index = mShader->NameToBufferInfoIndex(inName);
JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
mBufferInfos[index].buffer = buffer->GetBufferGPU();
mUsedBuffers.insert(buffer);
}
void ComputeQueueVK::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
{
if (inBuffer == nullptr)
return;
JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
if (!BeginCommandBuffer())
return;
const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
SyncCPUToGPU(buffer);
buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VkAccessFlagBits(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT), inBarrier == EBarrier::Yes);
uint index = mShader->NameToBufferInfoIndex(inName);
JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
mBufferInfos[index].buffer = buffer->GetBufferGPU();
mUsedBuffers.insert(buffer);
}
void ComputeQueueVK::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
{
if (inDst == nullptr || inSrc == nullptr)
return;
JPH_ASSERT(inDst->GetType() == ComputeBuffer::EType::ReadbackBuffer);
if (!BeginCommandBuffer())
return;
const ComputeBufferVK *src_vk = static_cast<const ComputeBufferVK *>(inSrc);
const ComputeBufferVK *dst_vk = static_cast<ComputeBufferVK *>(inDst);
// Barrier to start reading from GPU buffer and writing to CPU buffer
src_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_READ_BIT, false);
dst_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_WRITE_BIT, false);
// Copy
VkBufferCopy copy = {};
copy.srcOffset = 0;
copy.dstOffset = 0;
copy.size = src_vk->GetSize() * src_vk->GetStride();
vkCmdCopyBuffer(mCommandBuffer, src_vk->GetBufferGPU(), dst_vk->GetBufferCPU(), 1, &copy);
// Barrier to indicate that CPU can read from the buffer
dst_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_ACCESS_HOST_READ_BIT, false);
mUsedBuffers.insert(src_vk);
mUsedBuffers.insert(dst_vk);
}
void ComputeQueueVK::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
{
if (!BeginCommandBuffer())
return;
vkCmdBindPipeline(mCommandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, mShader->GetPipeline());
VkDevice device = mComputeSystem->GetDevice();
const Array<VkDescriptorSetLayoutBinding> &ds_bindings = mShader->GetLayoutBindings();
if (!ds_bindings.empty())
{
// Create a descriptor set
VkDescriptorSetAllocateInfo alloc_info = {};
alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
alloc_info.descriptorPool = mDescriptorPool;
alloc_info.descriptorSetCount = 1;
VkDescriptorSetLayout ds_layout = mShader->GetDescriptorSetLayout();
alloc_info.pSetLayouts = &ds_layout;
VkDescriptorSet descriptor_set;
if (VKFailed(vkAllocateDescriptorSets(device, &alloc_info, &descriptor_set)))
return;
// Write the values to the descriptor set
Array<VkWriteDescriptorSet> writes;
writes.reserve(ds_bindings.size());
for (uint32 i = 0; i < (uint32)ds_bindings.size(); ++i)
{
VkWriteDescriptorSet w = {};
w.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
w.dstSet = descriptor_set;
w.dstBinding = ds_bindings[i].binding;
w.dstArrayElement = 0;
w.descriptorCount = ds_bindings[i].descriptorCount;
w.descriptorType = ds_bindings[i].descriptorType;
w.pBufferInfo = &mBufferInfos[i];
writes.push_back(w);
}
vkUpdateDescriptorSets(device, (uint32)writes.size(), writes.data(), 0, nullptr);
// Bind the descriptor set
vkCmdBindDescriptorSets(mCommandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, mShader->GetPipelineLayout(), 0, 1, &descriptor_set, 0, nullptr);
}
vkCmdDispatch(mCommandBuffer, inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
}
void ComputeQueueVK::Execute()
{
// End command buffer
if (!mCommandBufferRecording)
return;
if (VKFailed(vkEndCommandBuffer(mCommandBuffer)))
return;
mCommandBufferRecording = false;
// Reset fence
VkDevice device = mComputeSystem->GetDevice();
if (VKFailed(vkResetFences(device, 1, &mFence)))
return;
// Submit
VkSubmitInfo submit = {};
submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
submit.commandBufferCount = 1;
submit.pCommandBuffers = &mCommandBuffer;
if (VKFailed(vkQueueSubmit(mQueue, 1, &submit, mFence)))
return;
// Clear the current shader
mShader = nullptr;
// Mark that we're executing
mIsExecuting = true;
}
void ComputeQueueVK::Wait()
{
if (!mIsExecuting)
return;
// Wait for the work to complete
VkDevice device = mComputeSystem->GetDevice();
if (VKFailed(vkWaitForFences(device, 1, &mFence, VK_TRUE, UINT64_MAX)))
return;
// Reset command buffer so it can be reused
if (mCommandBuffer != VK_NULL_HANDLE)
vkResetCommandBuffer(mCommandBuffer, 0);
// Allow reusing the descriptors for next run
vkResetDescriptorPool(device, mDescriptorPool, 0);
// Buffers can be freed now
mUsedBuffers.clear();
// Free delayed buffers
for (BufferVK &buffer : mDelayedFreedBuffers)
mComputeSystem->FreeBuffer(buffer);
mDelayedFreedBuffers.clear();
mIsExecuting = false;
}
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,66 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeQueue.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeShaderVK.h>
#include <Jolt/Compute/VK/BufferVK.h>
#include <Jolt/Core/UnorderedMap.h>
#include <Jolt/Core/UnorderedSet.h>
JPH_NAMESPACE_BEGIN
class ComputeSystemVK;
class ComputeBufferVK;
/// A command queue for Vulkan for executing compute workloads on the GPU.
class JPH_EXPORT ComputeQueueVK final : public ComputeQueue
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor / Destructor
explicit ComputeQueueVK(ComputeSystemVK *inComputeSystem) : mComputeSystem(inComputeSystem) { }
virtual ~ComputeQueueVK() override;
/// Initialize the queue
bool Initialize(uint32 inComputeQueueIndex, ComputeQueueResult &outResult);
// See: ComputeQueue
virtual void SetShader(const ComputeShader *inShader) override;
virtual void SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
virtual void SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
virtual void ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
virtual void Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
virtual void Execute() override;
virtual void Wait() override;
private:
bool BeginCommandBuffer();
// Copy the CPU buffer to the GPU buffer if needed
void SyncCPUToGPU(const ComputeBufferVK *inBuffer);
ComputeSystemVK * mComputeSystem;
VkQueue mQueue = VK_NULL_HANDLE;
VkCommandPool mCommandPool = VK_NULL_HANDLE;
VkDescriptorPool mDescriptorPool = VK_NULL_HANDLE;
VkCommandBuffer mCommandBuffer = VK_NULL_HANDLE;
bool mCommandBufferRecording = false; ///< If we are currently recording commands into the command buffer
VkFence mFence = VK_NULL_HANDLE;
bool mIsExecuting = false; ///< If Execute has been called and we are waiting for it to finish
RefConst<ComputeShaderVK> mShader; ///< Shader that has been activated
Array<VkDescriptorBufferInfo> mBufferInfos; ///< List of parameters that will be sent to the current shader
UnorderedSet<RefConst<ComputeBuffer>> mUsedBuffers; ///< Buffers that are in use by the current execution, these will be retained until execution is finished so that we don't free buffers that are in use
Array<BufferVK> mDelayedFreedBuffers; ///< Hardware buffers that need to be freed after execution is done
};
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,232 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeShaderVK.h>
JPH_NAMESPACE_BEGIN
ComputeShaderVK::~ComputeShaderVK()
{
if (mShaderModule != VK_NULL_HANDLE)
vkDestroyShaderModule(mDevice, mShaderModule, nullptr);
if (mDescriptorSetLayout != VK_NULL_HANDLE)
vkDestroyDescriptorSetLayout(mDevice, mDescriptorSetLayout, nullptr);
if (mPipelineLayout != VK_NULL_HANDLE)
vkDestroyPipelineLayout(mDevice, mPipelineLayout, nullptr);
if (mPipeline != VK_NULL_HANDLE)
vkDestroyPipeline(mDevice, mPipeline, nullptr);
}
bool ComputeShaderVK::Initialize(const Array<uint8> &inSPVCode, VkBuffer inDummyBuffer, ComputeShaderResult &outResult)
{
const uint32 *spv_words = reinterpret_cast<const uint32 *>(inSPVCode.data());
size_t spv_word_count = inSPVCode.size() / sizeof(uint32);
// Minimal SPIR-V parser to extract name to binding info
UnorderedMap<uint32, String> id_to_name;
UnorderedMap<uint32, uint32> id_to_binding;
UnorderedMap<uint32, VkDescriptorType> id_to_descriptor_type;
UnorderedMap<uint32, uint32> pointer_to_pointee;
UnorderedMap<uint32, uint32> var_to_ptr_type;
size_t i = 5; // Skip 5 word header
while (i < spv_word_count)
{
// Parse next word
uint32 word = spv_words[i];
uint16 opcode = uint16(word & 0xffff);
uint16 word_count = uint16(word >> 16);
if (word_count == 0 || i + word_count > spv_word_count)
break;
switch (opcode)
{
case 5: // OpName
if (word_count >= 2)
{
uint32 target_id = spv_words[i + 1];
const char* name = reinterpret_cast<const char*>(&spv_words[i + 2]);
if (*name != 0)
id_to_name.insert({ target_id, name });
}
break;
case 16: // OpExecutionMode
if (word_count >= 6)
{
uint32 execution_mode = spv_words[i + 2];
if (execution_mode == 17) // LocalSize
{
// Assert that the group size provided matches the one in the shader
JPH_ASSERT(GetGroupSizeX() == spv_words[i + 3], "Group size X mismatch");
JPH_ASSERT(GetGroupSizeY() == spv_words[i + 4], "Group size Y mismatch");
JPH_ASSERT(GetGroupSizeZ() == spv_words[i + 5], "Group size Z mismatch");
}
}
break;
case 32: // OpTypePointer
if (word_count >= 4)
{
uint32 result_id = spv_words[i + 1];
uint32 type_id = spv_words[i + 3];
pointer_to_pointee.insert({ result_id, type_id });
}
break;
case 59: // OpVariable
if (word_count >= 3)
{
uint32 ptr_type_id = spv_words[i + 1];
uint32 result_id = spv_words[i + 2];
var_to_ptr_type.insert({ result_id, ptr_type_id });
}
break;
case 71: // OpDecorate
if (word_count >= 3)
{
uint32 target_id = spv_words[i + 1];
uint32 decoration = spv_words[i + 2];
if (decoration == 2) // Block
{
id_to_descriptor_type.insert({ target_id, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER });
}
else if (decoration == 3) // BufferBlock
{
id_to_descriptor_type.insert({ target_id, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER });
}
else if (decoration == 33 && word_count >= 4) // Binding
{
uint32 binding = spv_words[i + 3];
id_to_binding.insert({ target_id, binding });
}
}
break;
default:
break;
}
i += word_count;
}
// Build name to binding map
UnorderedMap<String, std::pair<uint32, VkDescriptorType>> name_to_binding;
for (const UnorderedMap<uint32, uint32>::value_type &entry : id_to_binding)
{
uint32 target_id = entry.first;
uint32 binding = entry.second;
// Get the name of the variable
UnorderedMap<uint32, String>::const_iterator it_name = id_to_name.find(target_id);
if (it_name != id_to_name.end())
{
// Find variable that links to the target
UnorderedMap<uint32, uint32>::const_iterator it_var_ptr = var_to_ptr_type.find(target_id);
if (it_var_ptr != var_to_ptr_type.end())
{
// Find type pointed at
uint32 ptr_type = it_var_ptr->second;
UnorderedMap<uint32, uint32>::const_iterator it_pointee = pointer_to_pointee.find(ptr_type);
if (it_pointee != pointer_to_pointee.end())
{
uint32 pointee_type = it_pointee->second;
// Find descriptor type
UnorderedMap<uint32, VkDescriptorType>::iterator it_descriptor_type = id_to_descriptor_type.find(pointee_type);
VkDescriptorType descriptor_type = it_descriptor_type != id_to_descriptor_type.end() ? it_descriptor_type->second : VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
name_to_binding.insert({ it_name->second, { binding, descriptor_type } });
continue;
}
}
}
}
// Create layout bindings and buffer infos
if (!name_to_binding.empty())
{
mLayoutBindings.reserve(name_to_binding.size());
mBufferInfos.reserve(name_to_binding.size());
mBindingNames.reserve(name_to_binding.size());
for (const UnorderedMap<String, std::pair<uint32, VkDescriptorType>>::value_type &b : name_to_binding)
{
const String &name = b.first;
uint binding = b.second.first;
VkDescriptorType descriptor_type = b.second.second;
VkDescriptorSetLayoutBinding l = {};
l.binding = binding;
l.descriptorCount = 1;
l.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
l.descriptorType = descriptor_type;
mLayoutBindings.push_back(l);
mBindingNames.push_back(name); // Add all strings to a pool to keep them alive
mNameToBufferInfoIndex[string_view(mBindingNames.back())] = (uint32)mBufferInfos.size();
VkDescriptorBufferInfo bi = {};
bi.offset = 0;
bi.range = VK_WHOLE_SIZE;
bi.buffer = inDummyBuffer; // Avoid: The Vulkan spec states: If the nullDescriptor feature is not enabled, buffer must not be VK_NULL_HANDLE
mBufferInfos.push_back(bi);
}
// Create descriptor set layout
VkDescriptorSetLayoutCreateInfo layout_info = {};
layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
layout_info.bindingCount = (uint32)mLayoutBindings.size();
layout_info.pBindings = mLayoutBindings.data();
if (VKFailed(vkCreateDescriptorSetLayout(mDevice, &layout_info, nullptr, &mDescriptorSetLayout), outResult))
return false;
}
// Create pipeline layout
VkPipelineLayoutCreateInfo pl_info = {};
pl_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
pl_info.setLayoutCount = mDescriptorSetLayout != VK_NULL_HANDLE ? 1 : 0;
pl_info.pSetLayouts = mDescriptorSetLayout != VK_NULL_HANDLE ? &mDescriptorSetLayout : nullptr;
if (VKFailed(vkCreatePipelineLayout(mDevice, &pl_info, nullptr, &mPipelineLayout), outResult))
return false;
// Create shader module
VkShaderModuleCreateInfo create_info = {};
create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
create_info.codeSize = inSPVCode.size();
create_info.pCode = spv_words;
if (VKFailed(vkCreateShaderModule(mDevice, &create_info, nullptr, &mShaderModule), outResult))
return false;
// Create compute pipeline
VkComputePipelineCreateInfo pipe_info = {};
pipe_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
pipe_info.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
pipe_info.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
pipe_info.stage.module = mShaderModule;
pipe_info.stage.pName = "main";
pipe_info.layout = mPipelineLayout;
if (VKFailed(vkCreateComputePipelines(mDevice, VK_NULL_HANDLE, 1, &pipe_info, nullptr, &mPipeline), outResult))
return false;
return true;
}
uint32 ComputeShaderVK::NameToBufferInfoIndex(const char *inName) const
{
UnorderedMap<string_view, uint>::const_iterator it = mNameToBufferInfoIndex.find(inName);
JPH_ASSERT(it != mNameToBufferInfoIndex.end());
return it->second;
}
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,53 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeShader.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/IncludeVK.h>
#include <Jolt/Core/UnorderedMap.h>
JPH_NAMESPACE_BEGIN
/// Compute shader handle for Vulkan
class JPH_EXPORT ComputeShaderVK : public ComputeShader
{
public:
JPH_OVERRIDE_NEW_DELETE
/// Constructor / destructor
ComputeShaderVK(VkDevice inDevice, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) : ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ), mDevice(inDevice) { }
virtual ~ComputeShaderVK() override;
/// Initialize from SPIR-V code
bool Initialize(const Array<uint8> &inSPVCode, VkBuffer inDummyBuffer, ComputeShaderResult &outResult);
/// Get index of parameter in buffer infos
uint32 NameToBufferInfoIndex(const char *inName) const;
/// Getters
VkPipeline GetPipeline() const { return mPipeline; }
VkPipelineLayout GetPipelineLayout() const { return mPipelineLayout; }
VkDescriptorSetLayout GetDescriptorSetLayout() const { return mDescriptorSetLayout; }
const Array<VkDescriptorSetLayoutBinding> &GetLayoutBindings() const { return mLayoutBindings; }
const Array<VkDescriptorBufferInfo> &GetBufferInfos() const { return mBufferInfos; }
private:
VkDevice mDevice;
VkShaderModule mShaderModule = VK_NULL_HANDLE;
VkPipelineLayout mPipelineLayout = VK_NULL_HANDLE;
VkPipeline mPipeline = VK_NULL_HANDLE;
VkDescriptorSetLayout mDescriptorSetLayout = VK_NULL_HANDLE;
Array<String> mBindingNames; ///< A list of binding names, mNameToBufferInfoIndex points to these strings
UnorderedMap<string_view, uint32> mNameToBufferInfoIndex; ///< Binding name to buffer index, using a string_view so we can do find() without an allocation
Array<VkDescriptorSetLayoutBinding> mLayoutBindings;
Array<VkDescriptorBufferInfo> mBufferInfos;
};
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,118 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeSystemVK.h>
#include <Jolt/Compute/VK/ComputeShaderVK.h>
#include <Jolt/Compute/VK/ComputeBufferVK.h>
#include <Jolt/Compute/VK/ComputeQueueVK.h>
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_ABSTRACT(ComputeSystemVK)
{
JPH_ADD_BASE_CLASS(ComputeSystemVK, ComputeSystem)
}
bool ComputeSystemVK::Initialize(VkPhysicalDevice inPhysicalDevice, VkDevice inDevice, uint32 inComputeQueueIndex, ComputeSystemResult &outResult)
{
mPhysicalDevice = inPhysicalDevice;
mDevice = inDevice;
mComputeQueueIndex = inComputeQueueIndex;
// Get function to set a debug name
mVkSetDebugUtilsObjectNameEXT = reinterpret_cast<PFN_vkSetDebugUtilsObjectNameEXT>(reinterpret_cast<void *>(vkGetDeviceProcAddr(mDevice, "vkSetDebugUtilsObjectNameEXT")));
if (!InitializeMemory())
{
outResult.SetError("Failed to initialize memory subsystem");
return false;
}
// Create the dummy buffer. This is used to bind to shaders for which we have no buffer. We can't rely on VK_EXT_robustness2 being available to set nullDescriptor = VK_TRUE (it is unavailable on macOS).
if (!CreateBuffer(1024, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, mDummyBuffer))
{
outResult.SetError("Failed to create dummy buffer");
return false;
}
return true;
}
void ComputeSystemVK::Shutdown()
{
if (mDevice != VK_NULL_HANDLE)
vkDeviceWaitIdle(mDevice);
// Free the dummy buffer
FreeBuffer(mDummyBuffer);
ShutdownMemory();
}
ComputeShaderResult ComputeSystemVK::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
{
ComputeShaderResult result;
// Read shader source file
Array<uint8> data;
String file_name = String(inName) + ".spv";
String error;
if (!mShaderLoader(file_name.c_str(), data, error))
{
result.SetError(error);
return result;
}
Ref<ComputeShaderVK> shader = new ComputeShaderVK(mDevice, inGroupSizeX, inGroupSizeY, inGroupSizeZ);
if (!shader->Initialize(data, mDummyBuffer.mBuffer, result))
return result;
// Name the pipeline so we can easily find it in a profile
if (mVkSetDebugUtilsObjectNameEXT != nullptr)
{
VkDebugUtilsObjectNameInfoEXT info = {};
info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT;
info.pNext = nullptr;
info.objectType = VK_OBJECT_TYPE_PIPELINE;
info.objectHandle = (uint64)shader->GetPipeline();
info.pObjectName = inName;
mVkSetDebugUtilsObjectNameEXT(mDevice, &info);
}
result.Set(shader.GetPtr());
return result;
}
ComputeBufferResult ComputeSystemVK::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
{
ComputeBufferResult result;
Ref<ComputeBufferVK> buffer = new ComputeBufferVK(this, inType, inSize, inStride);
if (!buffer->Initialize(inData))
{
result.SetError("Failed to create compute buffer");
return result;
}
result.Set(buffer.GetPtr());
return result;
}
ComputeQueueResult ComputeSystemVK::CreateComputeQueue()
{
ComputeQueueResult result;
Ref<ComputeQueueVK> q = new ComputeQueueVK(this);
if (!q->Initialize(mComputeQueueIndex, result))
return result;
result.Set(q.GetPtr());
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,57 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Compute/ComputeSystem.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeQueueVK.h>
JPH_NAMESPACE_BEGIN
/// Interface to run a workload on the GPU using Vulkan.
/// Minimal implementation that can integrate with your own Vulkan setup.
class JPH_EXPORT ComputeSystemVK : public ComputeSystem
{
public:
JPH_DECLARE_RTTI_ABSTRACT(JPH_EXPORT, ComputeSystemVK)
// Initialize / shutdown the compute system
bool Initialize(VkPhysicalDevice inPhysicalDevice, VkDevice inDevice, uint32 inComputeQueueIndex, ComputeSystemResult &outResult);
void Shutdown();
// See: ComputeSystem
virtual ComputeShaderResult CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
virtual ComputeBufferResult CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
virtual ComputeQueueResult CreateComputeQueue() override;
/// Access to the Vulkan device
VkDevice GetDevice() const { return mDevice; }
/// Allow the application to override buffer creation and memory mapping in case it uses its own allocator
virtual bool CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer) = 0;
virtual void FreeBuffer(BufferVK &ioBuffer) = 0;
virtual void * MapBuffer(BufferVK &ioBuffer) = 0;
virtual void UnmapBuffer(BufferVK &ioBuffer) = 0;
protected:
/// Initialize / shutdown the memory subsystem
virtual bool InitializeMemory() = 0;
virtual void ShutdownMemory() = 0;
VkPhysicalDevice mPhysicalDevice = VK_NULL_HANDLE;
VkDevice mDevice = VK_NULL_HANDLE;
uint32 mComputeQueueIndex = 0;
PFN_vkSetDebugUtilsObjectNameEXT mVkSetDebugUtilsObjectNameEXT = nullptr;
private:
// Buffer that can be bound when we have no buffer
BufferVK mDummyBuffer;
};
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,330 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeSystemVKImpl.h>
#include <Jolt/Core/QuickSort.h>
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemVKImpl)
{
JPH_ADD_BASE_CLASS(ComputeSystemVKImpl, ComputeSystemVKWithAllocator)
}
#ifdef JPH_DEBUG
static VKAPI_ATTR VkBool32 VKAPI_CALL sVulkanDebugCallback(VkDebugUtilsMessageSeverityFlagBitsEXT inSeverity, [[maybe_unused]] VkDebugUtilsMessageTypeFlagsEXT inType, const VkDebugUtilsMessengerCallbackDataEXT *inCallbackData, [[maybe_unused]] void *inUserData)
{
if (inSeverity & (VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT))
Trace("VK: %s", inCallbackData->pMessage);
JPH_ASSERT((inSeverity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT) == 0);
return VK_FALSE;
}
#endif // JPH_DEBUG
ComputeSystemVKImpl::~ComputeSystemVKImpl()
{
ComputeSystemVK::Shutdown();
if (mDevice != VK_NULL_HANDLE)
vkDestroyDevice(mDevice, nullptr);
#ifdef JPH_DEBUG
PFN_vkDestroyDebugUtilsMessengerEXT vkDestroyDebugUtilsMessengerEXT = (PFN_vkDestroyDebugUtilsMessengerEXT)(void *)vkGetInstanceProcAddr(mInstance, "vkDestroyDebugUtilsMessengerEXT");
if (mInstance != VK_NULL_HANDLE && mDebugMessenger != VK_NULL_HANDLE && vkDestroyDebugUtilsMessengerEXT != nullptr)
vkDestroyDebugUtilsMessengerEXT(mInstance, mDebugMessenger, nullptr);
#endif
if (mInstance != VK_NULL_HANDLE)
vkDestroyInstance(mInstance, nullptr);
}
bool ComputeSystemVKImpl::Initialize(ComputeSystemResult &outResult)
{
// Required instance extensions
Array<const char *> required_instance_extensions;
required_instance_extensions.push_back(VK_KHR_SURFACE_EXTENSION_NAME);
required_instance_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
#ifdef JPH_PLATFORM_MACOS
required_instance_extensions.push_back("VK_KHR_portability_enumeration");
required_instance_extensions.push_back("VK_KHR_get_physical_device_properties2");
#endif
GetInstanceExtensions(required_instance_extensions);
// Required device extensions
Array<const char *> required_device_extensions;
required_device_extensions.push_back(VK_EXT_SCALAR_BLOCK_LAYOUT_EXTENSION_NAME);
#ifdef JPH_PLATFORM_MACOS
required_device_extensions.push_back("VK_KHR_portability_subset"); // VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME
#endif
GetDeviceExtensions(required_device_extensions);
// Query supported instance extensions
uint32 instance_extension_count = 0;
if (VKFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, nullptr), outResult))
return false;
Array<VkExtensionProperties> instance_extensions;
instance_extensions.resize(instance_extension_count);
if (VKFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, instance_extensions.data()), outResult))
return false;
// Query supported validation layers
uint32 validation_layer_count;
vkEnumerateInstanceLayerProperties(&validation_layer_count, nullptr);
Array<VkLayerProperties> validation_layers(validation_layer_count);
vkEnumerateInstanceLayerProperties(&validation_layer_count, validation_layers.data());
VkApplicationInfo app_info = {};
app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
app_info.apiVersion = VK_API_VERSION_1_1;
// Create Vulkan instance
VkInstanceCreateInfo instance_create_info = {};
instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
#ifdef JPH_PLATFORM_MACOS
instance_create_info.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
#endif
instance_create_info.pApplicationInfo = &app_info;
#ifdef JPH_DEBUG
// Enable validation layer if supported
const char *desired_validation_layers[] = { "VK_LAYER_KHRONOS_validation" };
for (const VkLayerProperties &p : validation_layers)
if (strcmp(desired_validation_layers[0], p.layerName) == 0)
{
instance_create_info.enabledLayerCount = 1;
instance_create_info.ppEnabledLayerNames = desired_validation_layers;
break;
}
// Setup debug messenger callback if the extension is supported
VkDebugUtilsMessengerCreateInfoEXT messenger_create_info = {};
for (const VkExtensionProperties &ext : instance_extensions)
if (strcmp(VK_EXT_DEBUG_UTILS_EXTENSION_NAME, ext.extensionName) == 0)
{
messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
messenger_create_info.pfnUserCallback = sVulkanDebugCallback;
instance_create_info.pNext = &messenger_create_info;
required_instance_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
break;
}
#endif
instance_create_info.enabledExtensionCount = (uint32)required_instance_extensions.size();
instance_create_info.ppEnabledExtensionNames = required_instance_extensions.data();
if (VKFailed(vkCreateInstance(&instance_create_info, nullptr, &mInstance), outResult))
return false;
#ifdef JPH_DEBUG
// Finalize debug messenger callback
PFN_vkCreateDebugUtilsMessengerEXT vkCreateDebugUtilsMessengerEXT = (PFN_vkCreateDebugUtilsMessengerEXT)(std::uintptr_t)vkGetInstanceProcAddr(mInstance, "vkCreateDebugUtilsMessengerEXT");
if (vkCreateDebugUtilsMessengerEXT != nullptr)
if (VKFailed(vkCreateDebugUtilsMessengerEXT(mInstance, &messenger_create_info, nullptr, &mDebugMessenger), outResult))
return false;
#endif
// Notify that instance has been created
OnInstanceCreated();
// Select device
uint32 device_count = 0;
if (VKFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, nullptr), outResult))
return false;
Array<VkPhysicalDevice> devices;
devices.resize(device_count);
if (VKFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, devices.data()), outResult))
return false;
struct Device
{
VkPhysicalDevice mPhysicalDevice;
String mName;
VkSurfaceFormatKHR mFormat;
uint32 mGraphicsQueueIndex;
uint32 mPresentQueueIndex;
uint32 mComputeQueueIndex;
int mScore;
};
Array<Device> available_devices;
for (VkPhysicalDevice device : devices)
{
// Get device properties
VkPhysicalDeviceProperties properties;
vkGetPhysicalDeviceProperties(device, &properties);
// Test if it is an appropriate type
int score = 0;
switch (properties.deviceType)
{
case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
score = 30;
break;
case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
score = 20;
break;
case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
score = 10;
break;
case VK_PHYSICAL_DEVICE_TYPE_CPU:
score = 5;
break;
case VK_PHYSICAL_DEVICE_TYPE_OTHER:
case VK_PHYSICAL_DEVICE_TYPE_MAX_ENUM:
continue;
}
// Check if the device supports all our required extensions
uint32 device_extension_count;
vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, nullptr);
Array<VkExtensionProperties> available_extensions;
available_extensions.resize(device_extension_count);
vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, available_extensions.data());
int found_extensions = 0;
for (const char *required_device_extension : required_device_extensions)
for (const VkExtensionProperties &ext : available_extensions)
if (strcmp(required_device_extension, ext.extensionName) == 0)
{
found_extensions++;
break;
}
if (found_extensions != int(required_device_extensions.size()))
continue;
// Find the right queues
uint32 queue_family_count = 0;
vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
Array<VkQueueFamilyProperties> queue_families;
queue_families.resize(queue_family_count);
vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
uint32 graphics_queue = ~uint32(0);
uint32 present_queue = ~uint32(0);
uint32 compute_queue = ~uint32(0);
for (uint32 i = 0; i < uint32(queue_families.size()); ++i)
{
if (queue_families[i].queueFlags & VK_QUEUE_GRAPHICS_BIT)
{
graphics_queue = i;
if (queue_families[i].queueFlags & VK_QUEUE_COMPUTE_BIT)
compute_queue = i;
}
if (HasPresentSupport(device, i))
present_queue = i;
if (graphics_queue != ~uint32(0) && present_queue != ~uint32(0) && compute_queue != ~uint32(0))
break;
}
if (graphics_queue == ~uint32(0) || present_queue == ~uint32(0) || compute_queue == ~uint32(0))
continue;
// Select surface format
VkSurfaceFormatKHR selected_format = SelectFormat(device);
if (selected_format.format == VK_FORMAT_UNDEFINED)
continue;
// Add the device
available_devices.push_back({ device, properties.deviceName, selected_format, graphics_queue, present_queue, compute_queue, score });
}
if (available_devices.empty())
{
outResult.SetError("No suitable Vulkan device found");
return false;
}
// Sort the devices by score
QuickSort(available_devices.begin(), available_devices.end(), [](const Device &inLHS, const Device &inRHS) {
return inLHS.mScore > inRHS.mScore;
});
const Device &selected_device = available_devices[0];
// Create device
float queue_priority = 1.0f;
VkDeviceQueueCreateInfo queue_create_info[3] = {};
for (VkDeviceQueueCreateInfo &q : queue_create_info)
{
q.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
q.queueCount = 1;
q.pQueuePriorities = &queue_priority;
}
uint32 num_queues = 0;
queue_create_info[num_queues++].queueFamilyIndex = selected_device.mGraphicsQueueIndex;
bool found = false;
for (uint32 i = 0; i < num_queues; ++i)
if (queue_create_info[i].queueFamilyIndex == selected_device.mPresentQueueIndex)
{
found = true;
break;
}
if (!found)
queue_create_info[num_queues++].queueFamilyIndex = selected_device.mPresentQueueIndex;
found = false;
for (uint32 i = 0; i < num_queues; ++i)
if (queue_create_info[i].queueFamilyIndex == selected_device.mComputeQueueIndex)
{
found = true;
break;
}
if (!found)
queue_create_info[num_queues++].queueFamilyIndex = selected_device.mComputeQueueIndex;
VkPhysicalDeviceScalarBlockLayoutFeatures enable_scalar_block = {};
enable_scalar_block.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SCALAR_BLOCK_LAYOUT_FEATURES;
enable_scalar_block.scalarBlockLayout = VK_TRUE;
VkPhysicalDeviceFeatures2 enabled_features2 = {};
enabled_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
GetEnabledFeatures(enabled_features2);
enable_scalar_block.pNext = enabled_features2.pNext;
enabled_features2.pNext = &enable_scalar_block;
VkDeviceCreateInfo device_create_info = {};
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.queueCreateInfoCount = num_queues;
device_create_info.pQueueCreateInfos = queue_create_info;
device_create_info.enabledLayerCount = instance_create_info.enabledLayerCount;
device_create_info.ppEnabledLayerNames = instance_create_info.ppEnabledLayerNames;
device_create_info.enabledExtensionCount = uint32(required_device_extensions.size());
device_create_info.ppEnabledExtensionNames = required_device_extensions.data();
device_create_info.pNext = &enabled_features2;
device_create_info.pEnabledFeatures = nullptr;
VkDevice device = VK_NULL_HANDLE;
if (VKFailed(vkCreateDevice(selected_device.mPhysicalDevice, &device_create_info, nullptr, &device), outResult))
return false;
// Get the queues
mGraphicsQueueIndex = selected_device.mGraphicsQueueIndex;
mPresentQueueIndex = selected_device.mPresentQueueIndex;
vkGetDeviceQueue(device, mGraphicsQueueIndex, 0, &mGraphicsQueue);
vkGetDeviceQueue(device, mPresentQueueIndex, 0, &mPresentQueue);
// Store selected format
mSelectedFormat = selected_device.mFormat;
// Initialize the compute system
return ComputeSystemVK::Initialize(selected_device.mPhysicalDevice, device, selected_device.mComputeQueueIndex, outResult);
}
ComputeSystemResult CreateComputeSystemVK()
{
ComputeSystemResult result;
Ref<ComputeSystemVKImpl> compute = new ComputeSystemVKImpl;
if (!compute->Initialize(result))
return result;
result.Set(compute.GetPtr());
return result;
}
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,57 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeSystemVKWithAllocator.h>
JPH_NAMESPACE_BEGIN
/// Implementation of ComputeSystemVK that fully initializes Vulkan
class JPH_EXPORT ComputeSystemVKImpl : public ComputeSystemVKWithAllocator
{
public:
JPH_DECLARE_RTTI_VIRTUAL(JPH_EXPORT, ComputeSystemVKImpl)
/// Destructor
virtual ~ComputeSystemVKImpl() override;
/// Initialize the compute system
bool Initialize(ComputeSystemResult &outResult);
protected:
/// Override to perform actions once the instance has been created
virtual void OnInstanceCreated() { /* Do nothing */ }
/// Override to add platform specific instance extensions
virtual void GetInstanceExtensions(Array<const char *> &outExtensions) { /* Add nothing */ }
/// Override to add platform specific device extensions
virtual void GetDeviceExtensions(Array<const char *> &outExtensions) { /* Add nothing */ }
/// Override to enable specific features
virtual void GetEnabledFeatures(VkPhysicalDeviceFeatures2 &ioFeatures) { /* Add nothing */ }
/// Override to check for present support on a given device and queue family
virtual bool HasPresentSupport(VkPhysicalDevice inDevice, uint32 inQueueFamilyIndex) { return true; }
/// Override to select the surface format
virtual VkSurfaceFormatKHR SelectFormat(VkPhysicalDevice inDevice) { return { VK_FORMAT_B8G8R8A8_UNORM, VK_COLOR_SPACE_SRGB_NONLINEAR_KHR }; }
VkInstance mInstance = VK_NULL_HANDLE;
#ifdef JPH_DEBUG
VkDebugUtilsMessengerEXT mDebugMessenger = VK_NULL_HANDLE;
#endif
uint32 mGraphicsQueueIndex = 0;
uint32 mPresentQueueIndex = 0;
VkQueue mGraphicsQueue = VK_NULL_HANDLE;
VkQueue mPresentQueue = VK_NULL_HANDLE;
VkSurfaceFormatKHR mSelectedFormat;
};
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,172 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#include <Jolt/Jolt.h>
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeSystemVKWithAllocator.h>
#include <Jolt/Compute/VK/ComputeShaderVK.h>
#include <Jolt/Compute/VK/ComputeBufferVK.h>
#include <Jolt/Compute/VK/ComputeQueueVK.h>
JPH_NAMESPACE_BEGIN
JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemVKWithAllocator)
{
JPH_ADD_BASE_CLASS(ComputeSystemVKWithAllocator, ComputeSystemVK)
}
bool ComputeSystemVKWithAllocator::InitializeMemory()
{
// Get memory properties
vkGetPhysicalDeviceMemoryProperties(mPhysicalDevice, &mMemoryProperties);
return true;
}
void ComputeSystemVKWithAllocator::ShutdownMemory()
{
// Free all memory
for (const MemoryCache::value_type &mc : mMemoryCache)
for (const Memory &m : mc.second)
if (m.mOffset == 0)
FreeMemory(*m.mMemory);
mMemoryCache.clear();
}
uint32 ComputeSystemVKWithAllocator::FindMemoryType(uint32 inTypeFilter, VkMemoryPropertyFlags inProperties) const
{
for (uint32 i = 0; i < mMemoryProperties.memoryTypeCount; i++)
if ((inTypeFilter & (1 << i))
&& (mMemoryProperties.memoryTypes[i].propertyFlags & inProperties) == inProperties)
return i;
JPH_ASSERT(false, "Failed to find memory type!");
return 0;
}
void ComputeSystemVKWithAllocator::AllocateMemory(VkDeviceSize inSize, uint32 inMemoryTypeBits, VkMemoryPropertyFlags inProperties, MemoryVK &ioMemory)
{
JPH_ASSERT(ioMemory.mMemory == VK_NULL_HANDLE);
ioMemory.mSize = inSize;
ioMemory.mProperties = inProperties;
VkMemoryAllocateInfo alloc_info = {};
alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
alloc_info.allocationSize = inSize;
alloc_info.memoryTypeIndex = FindMemoryType(inMemoryTypeBits, inProperties);
vkAllocateMemory(mDevice, &alloc_info, nullptr, &ioMemory.mMemory);
}
void ComputeSystemVKWithAllocator::FreeMemory(MemoryVK &ioMemory)
{
vkFreeMemory(mDevice, ioMemory.mMemory, nullptr);
ioMemory.mMemory = VK_NULL_HANDLE;
}
bool ComputeSystemVKWithAllocator::CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer)
{
// Create a new buffer
outBuffer.mSize = inSize;
VkBufferCreateInfo create_info = {};
create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
create_info.size = inSize;
create_info.usage = inUsage;
create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
if (VKFailed(vkCreateBuffer(mDevice, &create_info, nullptr, &outBuffer.mBuffer)))
{
outBuffer.mBuffer = VK_NULL_HANDLE;
return false;
}
VkMemoryRequirements mem_requirements;
vkGetBufferMemoryRequirements(mDevice, outBuffer.mBuffer, &mem_requirements);
if (mem_requirements.size > cMaxAllocSize)
{
// Allocate block directly
Ref<MemoryVK> memory_vk = new MemoryVK();
memory_vk->mBufferSize = mem_requirements.size;
AllocateMemory(mem_requirements.size, mem_requirements.memoryTypeBits, inProperties, *memory_vk);
outBuffer.mMemory = memory_vk;
outBuffer.mOffset = 0;
}
else
{
// Round allocation to the next power of 2 so that we can use a simple block based allocator
VkDeviceSize buffer_size = max(VkDeviceSize(GetNextPowerOf2(uint32(mem_requirements.size))), cMinAllocSize);
// Ensure that we have memory available from the right pool
Array<Memory> &mem_array = mMemoryCache[{ buffer_size, inProperties }];
if (mem_array.empty())
{
// Allocate a bigger block
Ref<MemoryVK> memory_vk = new MemoryVK();
memory_vk->mBufferSize = buffer_size;
AllocateMemory(cBlockSize, mem_requirements.memoryTypeBits, inProperties, *memory_vk);
// Divide into sub blocks
for (VkDeviceSize offset = 0; offset < cBlockSize; offset += buffer_size)
mem_array.push_back({ memory_vk, offset });
}
// Claim memory from the pool
Memory &memory = mem_array.back();
outBuffer.mMemory = memory.mMemory;
outBuffer.mOffset = memory.mOffset;
mem_array.pop_back();
}
// Bind the memory to the buffer
vkBindBufferMemory(mDevice, outBuffer.mBuffer, outBuffer.mMemory->mMemory, outBuffer.mOffset);
return true;
}
void ComputeSystemVKWithAllocator::FreeBuffer(BufferVK &ioBuffer)
{
if (ioBuffer.mBuffer != VK_NULL_HANDLE)
{
// Destroy the buffer
vkDestroyBuffer(mDevice, ioBuffer.mBuffer, nullptr);
ioBuffer.mBuffer = VK_NULL_HANDLE;
// Hand the memory back to the cache
VkDeviceSize buffer_size = ioBuffer.mMemory->mBufferSize;
if (buffer_size > cMaxAllocSize)
FreeMemory(*ioBuffer.mMemory);
else
mMemoryCache[{ buffer_size, ioBuffer.mMemory->mProperties }].push_back({ ioBuffer.mMemory, ioBuffer.mOffset });
ioBuffer = BufferVK();
}
}
void *ComputeSystemVKWithAllocator::MapBuffer(BufferVK& ioBuffer)
{
if (++ioBuffer.mMemory->mMappedCount == 1
&& VKFailed(vkMapMemory(mDevice, ioBuffer.mMemory->mMemory, 0, VK_WHOLE_SIZE, 0, &ioBuffer.mMemory->mMappedPtr)))
{
ioBuffer.mMemory->mMappedCount = 0;
return nullptr;
}
return static_cast<uint8 *>(ioBuffer.mMemory->mMappedPtr) + ioBuffer.mOffset;
}
void ComputeSystemVKWithAllocator::UnmapBuffer(BufferVK& ioBuffer)
{
JPH_ASSERT(ioBuffer.mMemory->mMappedCount > 0);
if (--ioBuffer.mMemory->mMappedCount == 0)
{
vkUnmapMemory(mDevice, ioBuffer.mMemory->mMemory);
ioBuffer.mMemory->mMappedPtr = nullptr;
}
}
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,70 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#ifdef JPH_USE_VK
#include <Jolt/Compute/VK/ComputeSystemVK.h>
#include <Jolt/Core/UnorderedMap.h>
JPH_NAMESPACE_BEGIN
/// This extends ComputeSystemVK to provide a default implementation for memory allocation and mapping.
/// It uses a simple block based allocator to reduce the number of allocations done to Vulkan.
class JPH_EXPORT ComputeSystemVKWithAllocator : public ComputeSystemVK
{
public:
JPH_DECLARE_RTTI_VIRTUAL(JPH_EXPORT, ComputeSystemVKWithAllocator)
/// Allow the application to override buffer creation and memory mapping in case it uses its own allocator
virtual bool CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer) override;
virtual void FreeBuffer(BufferVK &ioBuffer) override;
virtual void * MapBuffer(BufferVK &ioBuffer) override;
virtual void UnmapBuffer(BufferVK &ioBuffer) override;
protected:
virtual bool InitializeMemory() override;
virtual void ShutdownMemory() override;
uint32 FindMemoryType(uint32 inTypeFilter, VkMemoryPropertyFlags inProperties) const;
void AllocateMemory(VkDeviceSize inSize, uint32 inMemoryTypeBits, VkMemoryPropertyFlags inProperties, MemoryVK &ioMemory);
void FreeMemory(MemoryVK &ioMemory);
VkPhysicalDeviceMemoryProperties mMemoryProperties;
private:
// Smaller allocations (from cMinAllocSize to cMaxAllocSize) will be done in blocks of cBlockSize bytes.
// We do this because there is a limit to the number of allocations that we can make in Vulkan.
static constexpr VkDeviceSize cMinAllocSize = 512;
static constexpr VkDeviceSize cMaxAllocSize = 65536;
static constexpr VkDeviceSize cBlockSize = 524288;
struct MemoryKey
{
bool operator == (const MemoryKey &inRHS) const
{
return mSize == inRHS.mSize && mProperties == inRHS.mProperties;
}
VkDeviceSize mSize;
VkMemoryPropertyFlags mProperties;
};
JPH_MAKE_HASH_STRUCT(MemoryKey, MemoryKeyHasher, t.mProperties, t.mSize)
struct Memory
{
Ref<MemoryVK> mMemory;
VkDeviceSize mOffset;
};
using MemoryCache = UnorderedMap<MemoryKey, Array<Memory>, MemoryKeyHasher>;
MemoryCache mMemoryCache;
};
JPH_NAMESPACE_END
#endif // JPH_USE_VK

View File

@ -0,0 +1,44 @@
// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT
#pragma once
#include <Jolt/Core/StringTools.h>
#ifdef JPH_USE_VK
JPH_SUPPRESS_WARNINGS_STD_BEGIN
JPH_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic")
#include <vulkan/vulkan.h>
JPH_SUPPRESS_WARNINGS_STD_END
JPH_NAMESPACE_BEGIN
inline bool VKFailed(VkResult inResult)
{
if (inResult == VK_SUCCESS)
return false;
Trace("Vulkan call failed with error code: %d", (int)inResult);
JPH_ASSERT(false);
return true;
}
template <class Result>
inline bool VKFailed(VkResult inResult, Result &outResult)
{
if (inResult == VK_SUCCESS)
return false;
String error = StringFormat("Vulkan call failed with error code: %d", (int)inResult);
outResult.SetError(error);
JPH_ASSERT(false);
return true;
}
JPH_NAMESPACE_END
#endif // JPH_USE_VK