Bilateral Upsampling

こんにちわ,Pocolです。
今日はバイラテラルアップサンプリングについてメモをしておこうと思います。
パフォーマンスを稼ぐために,低解像度で描画しておき,それを元解像度に戻したいという場面が,ゲームグラフィックスでは多々出てきます。具体的には,SSAOやSSRなどの計算です。
ただ単にバイリニア補間で元解像度に戻してしまうとエッジ部分などでアーティファクトが発生してしまうことがあります。
こうしたアーティファクトを避けるために使われる手法の中の一つとして,Bilateral Upsamplingがあります。

通常のバイリニア補間は4点から計算を行います。

バイラテラルアップサンプリングは,法線と深度によってバイリニアウェイトを修正します。サンプルは以下のように,バイリニアの重み,法線の類似度による重み,深度の類似度による重みの3つによって重みづけされます。

バイリニアの重みは以下です。

法線の重みは次のように求めます。

深度の重みは次のように求めます。

以上から求められた重みを使ってサンプルを重みづけします。下図の通りです。

実装例ですが,もんしょさんが「DirectXの話 第121回 Bilateral Upsampling」の記事にてサンプルコードをアップしてくださっています。有難いです。
シェーダコードを抜粋すると下記の通りです。

float4 RenderUpsamplingPS( OutputVS inPixel ) : SV_TARGET
{
	const float2 kScreenSize = g_ScreenParam.xy * 2.0;
	const float2 kScreenHalfSize = g_ScreenParam.xy;
	const float4 kBilinearWeights[4] =
	{
		float4( 9.0/16.0, 3.0/16.0, 3.0/16.0, 1.0/16.0 ),
		float4( 3.0/16.0, 9.0/16.0, 1.0/16.0, 3.0/16.0 ),
		float4( 3.0/16.0, 1.0/16.0, 9.0/16.0, 3.0/16.0 ),
		float4( 1.0/16.0, 3.0/16.0, 3.0/16.0, 9.0/16.0 )
	};

	// Hi-Resピクセルのインデックスを求める
	int2 hiResUV = (int2)(inPixel.texCoord0 * kScreenSize + float2(0.1, 0.1));
	int hiResIndex = (1 - (hiResUV.y & 0x01)) * 2 + (1 - (hiResUV.x & 0x01));
	float4 hiResND = texNormalDepth.Load( int3(hiResUV, 0), int2(0, 0) );

	// Low-Resから4ピクセルの法線・深度を求める
	int2 lowResUV = (int2)(inPixel.texCoord0 * kScreenHalfSize.xy + float2(0.1, 0.1));
	float4 lowResND[4];
	float lowResAO[4];
	switch (hiResIndex)
	{
	case 0:
		lowResND[0] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, 0) );
		lowResND[1] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(1, 0) );
		lowResND[2] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, 1) );
		lowResND[3] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(1, 1) );
		lowResAO[0] = texHDAO.Load( int3(lowResUV, 0), int2(0, 0) ).r;
		lowResAO[1] = texHDAO.Load( int3(lowResUV, 0), int2(1, 0) ).r;
		lowResAO[2] = texHDAO.Load( int3(lowResUV, 0), int2(0, 1) ).r;
		lowResAO[3] = texHDAO.Load( int3(lowResUV, 0), int2(1, 1) ).r;
		break;
	case 1:
		lowResND[0] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(-1, 0) );
		lowResND[1] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, 0) );
		lowResND[2] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(-1, 1) );
		lowResND[3] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, 1) );
		lowResAO[0] = texHDAO.Load( int3(lowResUV, 0), int2(-1, 0) ).r;
		lowResAO[1] = texHDAO.Load( int3(lowResUV, 0), int2(0, 0) ).r;
		lowResAO[2] = texHDAO.Load( int3(lowResUV, 0), int2(-1, 1) ).r;
		lowResAO[3] = texHDAO.Load( int3(lowResUV, 0), int2(0, 1) ).r;
		break;
	case 2:
		lowResND[0] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, -1) );
		lowResND[1] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(1, -1) );
		lowResND[2] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, 0) );
		lowResND[3] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(1, 0) );
		lowResAO[0] = texHDAO.Load( int3(lowResUV, 0), int2(0, -1) ).r;
		lowResAO[1] = texHDAO.Load( int3(lowResUV, 0), int2(1, -1) ).r;
		lowResAO[2] = texHDAO.Load( int3(lowResUV, 0), int2(0, 0) ).r;
		lowResAO[3] = texHDAO.Load( int3(lowResUV, 0), int2(1, 0) ).r;
		break;
	case 3:
		lowResND[0] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(-1, -1) );
		lowResND[1] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, -1) );
		lowResND[2] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(-1, 0) );
		lowResND[3] = texHalfNormalDepth.Load( int3(lowResUV, 0), int2(0, 0) );
		lowResAO[0] = texHDAO.Load( int3(lowResUV, 0), int2(-1, -1) ).r;
		lowResAO[1] = texHDAO.Load( int3(lowResUV, 0), int2(0, -1) ).r;
		lowResAO[2] = texHDAO.Load( int3(lowResUV, 0), int2(-1, 0) ).r;
		lowResAO[3] = texHDAO.Load( int3(lowResUV, 0), int2(0, 0) ).r;
		break;
	}

	// 法線のウェイトを求める
	float totalWeight = 0.0;
	float ao = 0.0;
	for( int i = 0; i < 4; ++i )
	{
		// 法線のウェイトを求める
		float normalWeight = dot( lowResND[i].xyz, hiResND.xyz );
		normalWeight = pow( saturate(normalWeight), 32.0 );

		// 深度のウェイトを求める
		float depthDiff = hiResND.w - lowResND[i].w;
		float depthWeight = 1.0 / (1.0 + abs(depthDiff));

		// 総合する
		float weight = normalWeight * depthWeight * kBilinearWeights[hiResIndex][i];
		totalWeight += weight;
		ao += lowResAO[i] * weight;
	}

	ao /= totalWeight;

	return float4(ao, ao, ao, 1);
}

…ということで,Bilateral Upsamplingの話でした。
もしかしたら,Quad系のWaveIntrinsics使って実装した方がナウいかもしれないですね(※試してないので,出来なかったらごめんなさい)。

※追記
Quad Intrinsics使って実装できました。
WaveGetLaneIndex() % 4でhiResIndexを算出します。一度現在位置での,lowResNDとlowResAOを先頭の方でサンプリングしておき,あとはループでQuadReadLaneAt(lowResND, i)と QuadReadLaneAt(lowResAO, i)で,処理対象を持ってきます。これでswitchケース分が丸っとなくせるのと,テクスチャフェッチ回数が減らせます。