Step 1 in transcoding to ASTC - figure out how feasible the 1-partition mode really is
```python
import torch
from typing import List, Tuple

# dot, quantize_torch, and calculate_ideal_endpoints_pca are helpers defined
# elsewhere in the repo; sketches of plausible stand-ins follow below.

def find_best_1p_partition_torch_batched(
    blocks_batch: torch.Tensor,
    modes_to_test: List[Tuple[int, int]],
    endpoint_method: str = "minmax",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Search 1-partition modes for a batch of 4x4 RGBA blocks.

    blocks_batch is (B, 16, 4); modes_to_test is a list of
    (color states, weight states) pairs. Returns the best mode index,
    best MSE, and best reconstruction per block.
    """
    device = blocks_batch.device
    batch_size = blocks_batch.shape[0]
    num_modes = len(modes_to_test)
    batch_idx = torch.arange(batch_size, device=device)
    mode_color_states = torch.tensor([m[0] for m in modes_to_test], device=device).float().view(1, num_modes, 1, 1)
    mode_weight_states = torch.tensor([m[1] for m in modes_to_test], device=device).float().view(1, num_modes, 1)

    if endpoint_method == "minmax":
        # Cheap proxy: pick the darkest/brightest texels by summed RGB intensity.
        color_sums_batch = blocks_batch[:, :, :3].sum(dim=2)  # (B, 16)
        min_indices = torch.argmin(color_sums_batch, dim=1)   # (B,)
        max_indices = torch.argmax(color_sums_batch, dim=1)   # (B,)
        Ea_batch = blocks_batch[batch_idx, min_indices]       # (B, 4)
        Eb_batch = blocks_batch[batch_idx, max_indices]       # (B, 4)
    else:  # "pca" -- not vectorized yet, so this loop is the slow path
        Ea_batch = torch.zeros(batch_size, 4, device=device)
        Eb_batch = torch.zeros(batch_size, 4, device=device)
        for i in range(batch_size):
            Ea_batch[i], Eb_batch[i] = calculate_ideal_endpoints_pca(blocks_batch[i])

    # Ideal per-texel weight: project each texel onto the endpoint axis,
    # w = clamp((p - Ea) . (Eb - Ea) / |Eb - Ea|^2, 0, 1).
    E_delta_batch = Eb_batch - Ea_batch
    dot_delta_batch = dot(E_delta_batch, E_delta_batch).view(-1, 1) + 1e-9
    proj = dot(blocks_batch - Ea_batch.unsqueeze(1), E_delta_batch.unsqueeze(1)) / dot_delta_batch
    ideal_weights = torch.clamp(proj, 0, 1)  # (B, 16)

    # Unsqueeze Ea/Eb to (B, 1, 1, 4) so they broadcast against
    # mode_color_states (1, num_modes, 1, 1).
    q_Ea = quantize_torch(Ea_batch.unsqueeze(1).unsqueeze(2), mode_color_states)  # (B, num_modes, 1, 4)
    q_Eb = quantize_torch(Eb_batch.unsqueeze(1).unsqueeze(2), mode_color_states)  # (B, num_modes, 1, 4)
    q_weights = quantize_torch(ideal_weights.unsqueeze(1), mode_weight_states)    # (B, num_modes, 16)

    # Reconstruct every block under every mode: (B, num_modes, 16, 4)
    reconstructed_blocks_all_modes = q_Ea + q_weights.unsqueeze(-1) * (q_Eb - q_Ea)

    # Pick the lowest-MSE mode per block rather than hard-coding mode 0.
    mses = ((reconstructed_blocks_all_modes - blocks_batch.unsqueeze(1)) ** 2).mean(dim=(2, 3))  # (B, num_modes)
    best_mses, best_modes = mses.min(dim=1)
    reconstructed_blocks = reconstructed_blocks_all_modes[batch_idx, best_modes]  # (B, 16, 4)
    return best_modes, best_mses, reconstructed_blocks
```
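For reference, here's a rough sketch of plausible stand-ins for those helpers, plus a driver call. The exact repo implementations may differ; the (12, 12) mode pair mirrors the 12 ep_q / 12 cw_q setting below, and (8, 8) is just an extra illustrative mode:

```python
import torch
from typing import Tuple

def dot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Dot product over the channel (last) dimension.
    return (a * b).sum(dim=-1)

def quantize_torch(x: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
    # Snap values (assumed in [0, 1]) to `states` evenly spaced levels.
    levels = states - 1.0
    return torch.round(x.clamp(0.0, 1.0) * levels) / levels

def calculate_ideal_endpoints_pca(block: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Endpoints at the extremes of the texels' principal axis.
    mean = block.mean(dim=0)
    centered = block - mean
    _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
    axis = Vh[0]         # first principal direction
    t = centered @ axis  # scalar projection of each texel
    return mean + t.min() * axis, mean + t.max() * axis

# Driver: a 1920x1080 image is 480 * 270 = 129600 4x4 blocks.
blocks = torch.rand(129600, 16, 4, device="cuda" if torch.cuda.is_available() else "cpu")
best_modes, best_mses, recon = find_best_1p_partition_torch_batched(
    blocks, modes_to_test=[(12, 12), (8, 8)], endpoint_method="minmax"
)
```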
Running this vectorized implementation on a T4 on Colab:

```
Search complete in 0.0110 seconds.
Found best modes and MSEs for a batch of 129600 blocks. (1920 x 1080)
```
Comparison screenshots:
- Original: [screenshot]
- ASTC (4x4, p1, 12 ep_q, 12 cw_q): [screenshot]
- With PCA endpoint selection (better quality, but slower at 65 s; not vectorized on the GPU yet, since I don't want to write a CUDA kernel just for prototyping): [screenshot]
- With minmax (color intensity/luminance-based proxy) endpoint selection (less accurate, but really fast on the GPU at 0.01 seconds): [screenshot]
Granted, this was all in PyTorch, so we don't have the fully fused CUDA kernel doing this end-to-end that we would otherwise have within Vulkan. But the answer seems to be yes: by limiting ourselves to 1 partition, an approximate, semi-accurate transcode to ASTC is dirt fast (under 1 s for a 1080p image).
I am also considering implementing https://onlinelibrary.wiley.com/doi/abs/10.1111/cgf.13534, specifically as a CNN for the subproblem of partition index selection for p=2. This would drastically increase the cost per block, from a few thousand flops to roughly 100k flops (~0.1 ms per block), so the hope is that we won't need to run it exhaustively on every block, only on the ones with high MSE relative to the original block. Since single-block encoding is so cheap, and seems to work well for most realistic textures (it does horribly on random noise, but we rarely see that except in noise maps, which aren't typically compressed in BC6/7), this opens a future avenue to do something like the following (a rough dispatch sketch follows the list):
- For BC1/2/3 - transcode to ASTC 4x4 p=1 (since BC1/3 are equivalent to p=1)
- For BC4/5 - target an appropriate non-rgba8 format
- For BC6/7 - transcode to ASTC 4x4 p=1, with MSE thresholds used to decide when to fall back to the CNN for a p=2 partition index (which will be slow, but may be worth it for a high-fidelity texture that performs poorly at p=1)
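A rough sketch of that dispatch (every function name and the MSE threshold here are placeholders, not an existing API):

```python
# Hypothetical routing for the pipeline above; encode_astc_* and
# cnn_predict_p2_index are placeholders for code that doesn't exist yet.
def transcode_block(bc_format: str, block, mse_threshold: float = 1e-3):
    if bc_format in ("BC1", "BC2", "BC3"):
        # Already equivalent to a single partition.
        return encode_astc_4x4_p1(block)
    if bc_format in ("BC4", "BC5"):
        # One/two-channel data: target a non-rgba8 layout.
        return encode_astc_non_rgba8(block)
    # BC6/7: cheap p=1 first, CNN-predicted p=2 only for poor reconstructions.
    encoded, mse = encode_astc_4x4_p1(block, return_mse=True)
    if mse > mse_threshold:
        partition_index = cnn_predict_p2_index(block)  # the slow CNN path
        encoded = encode_astc_4x4_p2(block, partition_index)
    return encoded
```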
The CNN will take a vector of dim 64 (a 4x4 block with 4 channels), expand it through one conv hidden layer (maybe 128 wide?), then an MLP (maybe 256), before expanding/decoding to the target dim of 1024 for the partition index classification. Training this even on a T4 should take no more than a few hours, with most of the time going towards preparing the dataset, since we need game screenshot tiles labeled with their optimal p=2 index somehow.
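A minimal PyTorch sketch of that shape (the kernel size, activations, and everything beyond the rough 64 -> 128 -> 256 -> 1024 widths above are guesses):

```python
import torch
import torch.nn as nn

class PartitionIndexCNN(nn.Module):
    """Classify a 4x4 RGBA block (64 values) into one of 1024 p=2 partition indices."""
    def __init__(self, num_partitions: int = 1024):
        super().__init__()
        # Treat the block as a 4-channel 4x4 image for the conv layer.
        self.conv = nn.Sequential(
            nn.Conv2d(4, 128, kernel_size=3, padding=1),  # ~128-wide conv hidden layer
            nn.ReLU(),
        )
        self.mlp = nn.Sequential(
            nn.Flatten(),                    # 128 * 4 * 4 = 2048
            nn.Linear(128 * 4 * 4, 256),     # ~256-wide MLP
            nn.ReLU(),
            nn.Linear(256, num_partitions),  # logits over partition indices
        )

    def forward(self, blocks: torch.Tensor) -> torch.Tensor:
        # blocks: (B, 16, 4) -> (B, 4, 4, 4) channels-first
        x = blocks.view(-1, 4, 4, 4).permute(0, 3, 1, 2)
        return self.mlp(self.conv(x))
```

Training would then be plain cross-entropy against the precomputed optimal p=2 indices for each tile.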
Afterwards, we can use the same minmax endpoint selection algorithm, or PCA (since p=2 encoding is already a bit more involved).
Anyway, exciting stuff. I'm surprised that the minmax p=1 mode works as well as it does.
