コンピュートシェーダで実行する際は…

こんにちわ,Pocolです。
最近、最適化の話とかを見るのがちょっとハマっています。

NVIDIAがthread-group ID swizzlingという最適化テクニックについての記事を投稿しています。

https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/

L2キャッシュを再利用できるようにアクセスパターンを変えることにより最適化を行うテクニックのようです。
2Dフルスクリーンのコンピュートシェーダを用いるものに重要となるテクニックだそうで,ポストプロセスやスクリーンスペース系の技法を実装する際には重宝しそうです。

上記のテクニックはGDC 2019で紹介されているもので,バトルフィールド5ではRTX 2080(1440p)で0.75msの改善があったと報告されています(SetStablePowerState(TRUE)での動作だそうです)。
また,GDC 2019で紹介したソースコードにバグがあり,X方向(N)に起動するスレッドグループの数の倍数である場合にのみ動作するものだったそうです。
修正したソースコードについても提示がされています。

上記の記事のHLSLコードが実際動くのか,コピってみて試したのがだめでした。
NVIDIAのWebページの方では,いくつかHTMLの変換ミスがあるっぽくてアスタリスク(*)が無くなったりしていて,そのままコピペしてもビルドエラーになるので注意してください。
そこで,D3D11で動くように実装を修正してみました。下記のような感じです。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// スレッドサイズ.
#define THREAD_SIZE (8)
 
// Shader Model 5系かどうか?
#define IS_SM5 (1)
 
///////////////////////////////////////////////////////////////////////////////
// ColorFilterParam structure
///////////////////////////////////////////////////////////////////////////////
cbuffer CbColorFilter : register(b0)
{
    uint2       DipsatchArgs : packoffset(c0);   // Dispatch()メソッドに渡した引数.
    float4x4    ColorMatrix  : packoffset(c1);   // カラー変換行列.
};
 
//-----------------------------------------------------------------------------
// Resources.
//-----------------------------------------------------------------------------
Texture2D<float4>   Input   : register(t0);
RWTexture2D<float4> Output  : register(u0);
 
 
//-----------------------------------------------------------------------------
//! @brief      スレッドグループのタイリングを行う.
//!
//! @param[in]      dispatchGridDim     Dipatch(X, Y, Z)で渡した(X, Y)の値.
//! @param[in]      groupId             グループID
//! @param[in]      groupTheradId       グループスレッドID.
//! @return     スレッドIDを返却する.
//-----------------------------------------------------------------------------
uint2 CalcSwizzledThreaId(uint2 dispatchDim, uint2 groupId, uint2 groupThreadId)
{
    // "CTA" (Cooperative Thread Array) == Thread Group in DirectX terminology
    const uint2 CTA_Dim = uint2(THREAD_SIZE, THREAD_SIZE);
    const uint N = 16; // 16 スレッドグループで起動.
 
    // 1タイル内のスレッドグループの総数.
    uint number_of_CTAs_in_a_perfect_tile = N * (dispatchDim.y);
 
    // 考えうる完全なタイルの数.
    uint number_of_perfect_tiles = dispatchDim.x / N;
 
    // 完全なタイルにおけるスレッドグループの総数.
    uint total_CTAs_in_all_perfect_tiles = number_of_perfect_tiles * N * dispatchDim.y - 1;
    uint threadGroupIDFlattened = dispatchDim.x * groupId.y + groupId.x;
 
    // 現在のスレッドグループからタイルIDへのマッピング.
    uint tile_ID_of_current_CTA = threadGroupIDFlattened / number_of_CTAs_in_a_perfect_tile;
    uint local_CTA_ID_within_current_tile = threadGroupIDFlattened % number_of_CTAs_in_a_perfect_tile;
 
    uint local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / N;
    uint local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % N;
  
    if (total_CTAs_in_all_perfect_tiles < threadGroupIDFlattened)
    {
        // 最後のタイルに不完全な次元があり、最後のタイルからのCTAが起動された場合にのみ実行されるパス.
        uint x_dimension_of_last_tile = dispatchDim.x % N;
    #if IS_SM5
        // SM5.0だとコンパイルエラーになるので対策.
        if (x_dimension_of_last_tile > 0)
        {
            local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / x_dimension_of_last_tile;
            local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % x_dimension_of_last_tile;
        }
    #else
        local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / x_dimension_of_last_tile;
        local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % x_dimension_of_last_tile;
    #endif
    }
 
    uint swizzledThreadGroupIDFlattened = tile_ID_of_current_CTA * N
      + local_CTA_ID_y_within_current_tile * dispatchDim.x
      + local_CTA_ID_x_within_current_tile;
 
    uint2 swizzledThreadGroupID;
    swizzledThreadGroupID.y = swizzledThreadGroupIDFlattened / dispatchDim.x;
    swizzledThreadGroupID.x = swizzledThreadGroupIDFlattened % dispatchDim.x;
 
    uint2 swizzledThreadID;
    swizzledThreadID.x = CTA_Dim.x * swizzledThreadGroupID.x + groupThreadId.x;
    swizzledThreadID.y = CTA_Dim.y * swizzledThreadGroupID.y + groupThreadId.y;
 
    return swizzledThreadID;
}
 
 
//-----------------------------------------------------------------------------
//      メインエントリーポイントです.
//-----------------------------------------------------------------------------
[numthreads(THREAD_SIZE, THREAD_SIZE, 1)]
void main
(
    uint3 groupId       : SV_GroupID,
    uint3 groupThreadId : SV_GroupThreadID
)
{
    uint2 id = CalcSwizzledThreaId(DipsatchArgs, groupId.xy, groupThreadId.xy);
    Output[id] = mul(ColorMatrix, Input[id]);
}

基本的には,いったんフラットなID(つまり通し番号)にして,そこから再算出するみたいな計算しているみたいです。
cpp側は下記のような感じです。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// カラーフィルタ実行.
{
    auto x = (m_TextureWidth  + m_ThreadCountX - 1) / m_ThreadCountX; // m_ThreadCountX = THREAD_SIZE. シェーダリフレクションで取得.
    auto y = (m_TextureHeight + m_ThreadCountY - 1) / m_ThreadCountY; // m_ThreadCountY = THREAD_SIZE. シェーダリフレクションで取得.
 
    auto pCB = m_CB.GetBuffer();
    CbColorFilter res = {};
    res.ThreadX = x;
    res.ThreadY = y;
    res.ColorMatrix = asdx::Matrix::CreateIdentity();
 
    m_pDeviceContext->UpdateSubresource(pCB, 0, nullptr, &res, 0, 0);
 
    auto pSRV = m_Texture.GetSRV();
    auto pUAV = m_ComputeUAV.GetPtr();
    m_CS.Bind(m_pDeviceContext.GetPtr());
    m_pDeviceContext->CSSetConstantBuffers(0, 1, &pCB);
    m_pDeviceContext->CSSetShaderResources(0, 1, &pSRV);
    m_pDeviceContext->CSSetUnorderedAccessViews(0, 1, &pUAV, nullptr);
    m_pDeviceContext->Dispatch(x, y, 1);
 
    ID3D11ShaderResourceView* pNullSRV[1] = {};
    ID3D11UnorderedAccessView* pNullUAV[1] = {};
    m_pDeviceContext->CSSetShaderResources(0, 1, pNullSRV);
    m_pDeviceContext->CSSetUnorderedAccessViews(0, 1, pNullUAV, nullptr);
    m_CS.UnBind(m_pDeviceContext.GetPtr());
}

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

This site uses Akismet to reduce spam. Learn how your comment data is processed.