こんにちわ,Pocolです。
最近、最適化の話とかを見るのがちょっとハマっています。
NVIDIAがthread-group ID swizzlingという最適化テクニックについての記事を投稿しています。
L2キャッシュを再利用できるようにアクセスパターンを変えることにより最適化を行うテクニックのようです。
2Dフルスクリーンのコンピュートシェーダを用いるものに重要となるテクニックだそうで,ポストプロセスやスクリーンスペース系の技法を実装する際には重宝しそうです。
上記のテクニックはGDC 2019で紹介されているもので,バトルフィールド5ではRTX 2080(1440p)で0.75msの改善があったと報告されています(SetStablePowerState(TRUE)での動作だそうです)。
また,GDC 2019で紹介したソースコードにバグがあり,X方向(N)に起動するスレッドグループの数の倍数である場合にのみ動作するものだったそうです。
修正したソースコードについても提示がされています。
上記の記事のHLSLコードが実際動くのか,コピってみて試したのがだめでした。
NVIDIAのWebページの方では,いくつかHTMLの変換ミスがあるっぽくてアスタリスク(*)が無くなったりしていて,そのままコピペしてもビルドエラーになるので注意してください。
そこで,D3D11で動くように実装を修正してみました。下記のような感じです。
// スレッドサイズ. #define THREAD_SIZE (8) // Shader Model 5系かどうか? #define IS_SM5 (1) /////////////////////////////////////////////////////////////////////////////// // ColorFilterParam structure /////////////////////////////////////////////////////////////////////////////// cbuffer CbColorFilter : register(b0) { uint2 DipsatchArgs : packoffset(c0); // Dispatch()メソッドに渡した引数. float4x4 ColorMatrix : packoffset(c1); // カラー変換行列. }; //----------------------------------------------------------------------------- // Resources. //----------------------------------------------------------------------------- Texture2D<float4> Input : register(t0); RWTexture2D<float4> Output : register(u0); //----------------------------------------------------------------------------- //! @brief スレッドグループのタイリングを行う. //! //! @param[in] dispatchGridDim Dipatch(X, Y, Z)で渡した(X, Y)の値. //! @param[in] groupId グループID //! @param[in] groupTheradId グループスレッドID. //! @return スレッドIDを返却する. //----------------------------------------------------------------------------- uint2 CalcSwizzledThreaId(uint2 dispatchDim, uint2 groupId, uint2 groupThreadId) { // "CTA" (Cooperative Thread Array) == Thread Group in DirectX terminology const uint2 CTA_Dim = uint2(THREAD_SIZE, THREAD_SIZE); const uint N = 16; // 16 スレッドグループで起動. // 1タイル内のスレッドグループの総数. uint number_of_CTAs_in_a_perfect_tile = N * (dispatchDim.y); // 考えうる完全なタイルの数. uint number_of_perfect_tiles = dispatchDim.x / N; // 完全なタイルにおけるスレッドグループの総数. uint total_CTAs_in_all_perfect_tiles = number_of_perfect_tiles * N * dispatchDim.y - 1; uint threadGroupIDFlattened = dispatchDim.x * groupId.y + groupId.x; // 現在のスレッドグループからタイルIDへのマッピング. uint tile_ID_of_current_CTA = threadGroupIDFlattened / number_of_CTAs_in_a_perfect_tile; uint local_CTA_ID_within_current_tile = threadGroupIDFlattened % number_of_CTAs_in_a_perfect_tile; uint local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / N; uint local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % N; if (total_CTAs_in_all_perfect_tiles < threadGroupIDFlattened) { // 最後のタイルに不完全な次元があり、最後のタイルからのCTAが起動された場合にのみ実行されるパス. uint x_dimension_of_last_tile = dispatchDim.x % N; #if IS_SM5 // SM5.0だとコンパイルエラーになるので対策. if (x_dimension_of_last_tile > 0) { local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / x_dimension_of_last_tile; local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % x_dimension_of_last_tile; } #else local_CTA_ID_y_within_current_tile = local_CTA_ID_within_current_tile / x_dimension_of_last_tile; local_CTA_ID_x_within_current_tile = local_CTA_ID_within_current_tile % x_dimension_of_last_tile; #endif } uint swizzledThreadGroupIDFlattened = tile_ID_of_current_CTA * N + local_CTA_ID_y_within_current_tile * dispatchDim.x + local_CTA_ID_x_within_current_tile; uint2 swizzledThreadGroupID; swizzledThreadGroupID.y = swizzledThreadGroupIDFlattened / dispatchDim.x; swizzledThreadGroupID.x = swizzledThreadGroupIDFlattened % dispatchDim.x; uint2 swizzledThreadID; swizzledThreadID.x = CTA_Dim.x * swizzledThreadGroupID.x + groupThreadId.x; swizzledThreadID.y = CTA_Dim.y * swizzledThreadGroupID.y + groupThreadId.y; return swizzledThreadID; } //----------------------------------------------------------------------------- // メインエントリーポイントです. //----------------------------------------------------------------------------- [numthreads(THREAD_SIZE, THREAD_SIZE, 1)] void main ( uint3 groupId : SV_GroupID, uint3 groupThreadId : SV_GroupThreadID ) { uint2 id = CalcSwizzledThreaId(DipsatchArgs, groupId.xy, groupThreadId.xy); Output[id] = mul(ColorMatrix, Input[id]); }
基本的には,いったんフラットなID(つまり通し番号)にして,そこから再算出するみたいな計算しているみたいです。
cpp側は下記のような感じです。
// カラーフィルタ実行. { auto x = (m_TextureWidth + m_ThreadCountX - 1) / m_ThreadCountX; // m_ThreadCountX = THREAD_SIZE. シェーダリフレクションで取得. auto y = (m_TextureHeight + m_ThreadCountY - 1) / m_ThreadCountY; // m_ThreadCountY = THREAD_SIZE. シェーダリフレクションで取得. auto pCB = m_CB.GetBuffer(); CbColorFilter res = {}; res.ThreadX = x; res.ThreadY = y; res.ColorMatrix = asdx::Matrix::CreateIdentity(); m_pDeviceContext->UpdateSubresource(pCB, 0, nullptr, &res, 0, 0); auto pSRV = m_Texture.GetSRV(); auto pUAV = m_ComputeUAV.GetPtr(); m_CS.Bind(m_pDeviceContext.GetPtr()); m_pDeviceContext->CSSetConstantBuffers(0, 1, &pCB); m_pDeviceContext->CSSetShaderResources(0, 1, &pSRV); m_pDeviceContext->CSSetUnorderedAccessViews(0, 1, &pUAV, nullptr); m_pDeviceContext->Dispatch(x, y, 1); ID3D11ShaderResourceView* pNullSRV[1] = {}; ID3D11UnorderedAccessView* pNullUAV[1] = {}; m_pDeviceContext->CSSetShaderResources(0, 1, pNullSRV); m_pDeviceContext->CSSetUnorderedAccessViews(0, 1, pNullUAV, nullptr); m_CS.UnBind(m_pDeviceContext.GetPtr()); }