コンピュートシェーダで実行する際は…

こんにちわ,Pocolです。
最近、最適化の話とかを見るのがちょっとハマっています。

NVIDIAがthread-group ID swizzlingという最適化テクニックについての記事を投稿しています。

https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/

L2キャッシュを再利用できるようにアクセスパターンを変えることにより最適化を行うテクニックのようです。
2Dフルスクリーンのコンピュートシェーダを用いるものに重要となるテクニックだそうで,ポストプロセスやスクリーンスペース系の技法を実装する際には重宝しそうです。

上記のテクニックはGDC 2019で紹介されているもので,バトルフィールド5ではRTX 2080(1440p)で0.75msの改善があったと報告されています(SetStablePowerState(TRUE)での動作だそうです)。
また,GDC 2019で紹介したソースコードにバグがあり,X方向(N)に起動するスレッドグループの数の倍数である場合にのみ動作するものだったそうです。
修正したソースコードについても提示がされています。

上記の記事のHLSLコードが実際動くのか,コピってみて試したのがだめでした。
NVIDIAのWebページの方では,いくつかHTMLの変換ミスがあるっぽくてアスタリスク(*)が無くなったりしていて,そのままコピペしてもビルドエラーになるので注意してください。
そこで,D3D11で動くように実装を修正してみました。下記のような感じです。

// スレッドサイズ.
#define THREAD_SIZE (8)

// Shader Model 5系かどうか?
#define IS_SM5 (1)

///////////////////////////////////////////////////////////////////////////////
// ColorFilterParam structure
///////////////////////////////////////////////////////////////////////////////
cbuffer CbColorFilter : register(b0)
{
    uint2       DipsatchArgs : packoffset(c0);   // Dispatch()メソッドに渡した引数.
    float4x4    ColorMatrix  : packoffset(c1);   // カラー変換行列.
};

//-----------------------------------------------------------------------------
// Resources.
//-----------------------------------------------------------------------------
Texture2D<float4>   Input   : register(t0);
RWTexture2D<float4> Output  : register(u0);


//-----------------------------------------------------------------------------
//! @brief      スレッドグループのタイリングを行う.
//!
//! @param[in]      dispatchGridDim     Dipatch(X, Y, Z)で渡した(X, Y)の値.
//! @param[in]      groupId             グループID
//! @param[in]      groupTheradId       グループスレッドID.
//! @return     スレッドIDを返却する.
//-----------------------------------------------------------------------------
uint2 CalcSwizzledThreaId(uint2 dispatchDim, uint2 groupId, uint2 groupThreadId)
{
    // "CTA" (Cooperative Thread Array) == Thread Group in DirectX terminology
    const uint2 CTA_Dim = uint2(THREAD_SIZE, THREAD_SIZE);
    const uint N = 16; // 16 スレッドグループで起動.

    // 1タイル内のスレッドグループの総数.
    uint number_of_CTAs_in_a_perfect_tile = N * (dispatchDim.y);

    // 考えうる完全なタイルの数.
    uint number_of_perfect_tiles = dispatchDim.x / N;

    // 完全なタイルにおけるスレッドグループの総数.
    uint total_CTAs_in_all_perfect_tiles = number_of_perfect_tiles * N * dispatchDim.y - 1;
    uint threadGroupIDFlattened = dispatchDim.x * groupId.y + groupId.x;

    // 現在のスレッドグループからタイルIDへのマッピング.
    uint tile_ID_of_current_CTA = threadGroupIDFlattened / number_of_CTAs_in_a_perfect_tile;
    uint local_CTA_ID_within_current_tile = threadGroupIDFlattened % number_of_CTAs_in_a_perfect_tile;

    uint local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / N;
    uint local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % N;
 
    if (total_CTAs_in_all_perfect_tiles < threadGroupIDFlattened)
    {
        // 最後のタイルに不完全な次元があり、最後のタイルからのCTAが起動された場合にのみ実行されるパス.
        uint x_dimension_of_last_tile = dispatchDim.x % N;
    #if IS_SM5
        // SM5.0だとコンパイルエラーになるので対策.
        if (x_dimension_of_last_tile > 0)
        {
            local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / x_dimension_of_last_tile;
            local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % x_dimension_of_last_tile;
        }
    #else
        local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / x_dimension_of_last_tile;
        local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % x_dimension_of_last_tile;
    #endif
    }

    uint swizzledThreadGroupIDFlattened = tile_ID_of_current_CTA * N
      + local_CTA_ID_y_within_current_tile * dispatchDim.x
      + local_CTA_ID_x_within_current_tile;

    uint2 swizzledThreadGroupID;
    swizzledThreadGroupID.y = swizzledThreadGroupIDFlattened / dispatchDim.x;
    swizzledThreadGroupID.x = swizzledThreadGroupIDFlattened % dispatchDim.x;

    uint2 swizzledThreadID;
    swizzledThreadID.x = CTA_Dim.x * swizzledThreadGroupID.x + groupThreadId.x;
    swizzledThreadID.y = CTA_Dim.y * swizzledThreadGroupID.y + groupThreadId.y;

    return swizzledThreadID;
}


//-----------------------------------------------------------------------------
//      メインエントリーポイントです.
//-----------------------------------------------------------------------------
[numthreads(THREAD_SIZE, THREAD_SIZE, 1)]
void main
(
    uint3 groupId       : SV_GroupID,
    uint3 groupThreadId : SV_GroupThreadID
)
{
    uint2 id = CalcSwizzledThreaId(DipsatchArgs, groupId.xy, groupThreadId.xy);
    Output[id] = mul(ColorMatrix, Input[id]);
}

基本的には,いったんフラットなID(つまり通し番号)にして,そこから再算出するみたいな計算しているみたいです。
cpp側は下記のような感じです。

    // カラーフィルタ実行.
    {
        auto x = (m_TextureWidth  + m_ThreadCountX - 1) / m_ThreadCountX; // m_ThreadCountX = THREAD_SIZE. シェーダリフレクションで取得.
        auto y = (m_TextureHeight + m_ThreadCountY - 1) / m_ThreadCountY; // m_ThreadCountY = THREAD_SIZE. シェーダリフレクションで取得.

        auto pCB = m_CB.GetBuffer();
        CbColorFilter res = {};
        res.ThreadX = x;
        res.ThreadY = y;
        res.ColorMatrix = asdx::Matrix::CreateIdentity();

        m_pDeviceContext->UpdateSubresource(pCB, 0, nullptr, &res, 0, 0);

        auto pSRV = m_Texture.GetSRV();
        auto pUAV = m_ComputeUAV.GetPtr();
        m_CS.Bind(m_pDeviceContext.GetPtr());
        m_pDeviceContext->CSSetConstantBuffers(0, 1, &pCB);
        m_pDeviceContext->CSSetShaderResources(0, 1, &pSRV);
        m_pDeviceContext->CSSetUnorderedAccessViews(0, 1, &pUAV, nullptr);
        m_pDeviceContext->Dispatch(x, y, 1);

        ID3D11ShaderResourceView* pNullSRV[1] = {};
        ID3D11UnorderedAccessView* pNullUAV[1] = {};
        m_pDeviceContext->CSSetShaderResources(0, 1, pNullSRV);
        m_pDeviceContext->CSSetUnorderedAccessViews(0, 1, pNullUAV, nullptr);
        m_CS.UnBind(m_pDeviceContext.GetPtr());
    }

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください