🐛 fix: stream synchronization for XAttention estimate kernels in offload mode
- Wrap all compute kernels in select_blocks with compute_stream context (Pass 1 historical blocks, Pass 1 current chunk, Step 2 merge, Pass 2 historical blocks, Pass 2 current chunk, Step 4 block selection) - Fix K data mismatch between Pass 1 and Pass 2 by ensuring wait_slot_layer syncs with compute_stream where kernels actually run - Remove STRONG SYNC code from offload_engine.py (now handled by events) - Remove debug print statements and torch.save code - Consolidate fallback conditions in compute_with_xattn - Change default chunk_size from 16384 to 4096 for density alignment The bug caused Pass 1 and Pass 2 to see different K data from the same CPU block because compute kernels ran on default stream while wait_slot_layer only synced compute_stream. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -96,7 +96,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
self,
|
self,
|
||||||
threshold: float = 0.95, # High threshold for accuracy testing
|
threshold: float = 0.95, # High threshold for accuracy testing
|
||||||
stride: int = 8,
|
stride: int = 8,
|
||||||
chunk_size: int = 16384,
|
chunk_size: int = 4096, # Match offload Q chunk size for density alignment
|
||||||
block_size: int = 128,
|
block_size: int = 128,
|
||||||
samples_per_chunk: int = 128,
|
samples_per_chunk: int = 128,
|
||||||
use_triton: bool = True,
|
use_triton: bool = True,
|
||||||
@@ -289,9 +289,11 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
Attention output [total_q, num_heads, head_dim]
|
Attention output [total_q, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
# When block_tables is provided (paged KV cache / prefix cache),
|
# Fallback to flash attention when:
|
||||||
# fallback to flash_attn as XAttention expects contiguous K, V
|
# 1. block_tables provided (paged KV cache / prefix cache) - XAttention expects contiguous K, V
|
||||||
if block_tables is not None:
|
# 2. BSA kernel not available
|
||||||
|
# 3. xattn_estimate not available
|
||||||
|
if block_tables is not None or not BSA_AVAILABLE or not XATTN_AVAILABLE:
|
||||||
from flash_attn import flash_attn_varlen_func
|
from flash_attn import flash_attn_varlen_func
|
||||||
return flash_attn_varlen_func(
|
return flash_attn_varlen_func(
|
||||||
q, k, v,
|
q, k, v,
|
||||||
@@ -304,32 +306,6 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
block_table=block_tables,
|
block_table=block_tables,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not BSA_AVAILABLE:
|
|
||||||
# Fallback to flash attention if BSA not available
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
return flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
|
||||||
max_seqlen_q=max_seqlen_q,
|
|
||||||
max_seqlen_k=max_seqlen_k,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not XATTN_AVAILABLE:
|
|
||||||
# Fallback to flash attention if xattn not available
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
return flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
|
||||||
max_seqlen_q=max_seqlen_q,
|
|
||||||
max_seqlen_k=max_seqlen_k,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
from nanovllm.ops.xattn import xattn_estimate
|
from nanovllm.ops.xattn import xattn_estimate
|
||||||
|
|
||||||
# Set DensityObserver mode on first layer
|
# Set DensityObserver mode on first layer
|
||||||
@@ -477,8 +453,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
|
causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
|
||||||
causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
|
causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
|
||||||
selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||||
logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
|
|
||||||
f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}")
|
|
||||||
DensityObserver.record(layer_id, mask_trimmed, causal=True)
|
DensityObserver.record(layer_id, mask_trimmed, causal=True)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -633,98 +608,108 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
l_chunks = []
|
l_chunks = []
|
||||||
num_kv_chunks = num_historical_blocks + 1 # +1 for current chunk
|
num_kv_chunks = num_historical_blocks + 1 # +1 for current chunk
|
||||||
|
|
||||||
|
# Get compute_stream for all compute kernels (like attention computation)
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
with nvtx.range("xattn_estimate_pass1"):
|
with nvtx.range("xattn_estimate_pass1"):
|
||||||
slot = 0
|
slot = 0
|
||||||
|
|
||||||
# Process historical blocks (from CPU)
|
# Process historical blocks (from CPU)
|
||||||
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
|
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
|
||||||
# Load K from CPU
|
# Load K from CPU (on slot_transfer_stream)
|
||||||
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||||
|
# wait_slot_layer makes compute_stream wait for H2D transfer
|
||||||
offload_engine.wait_slot_layer(slot)
|
offload_engine.wait_slot_layer(slot)
|
||||||
|
|
||||||
k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim]
|
# All compute kernels run on compute_stream (like attention computation)
|
||||||
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
|
with torch.cuda.stream(compute_stream):
|
||||||
|
k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim]
|
||||||
|
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
|
||||||
|
|
||||||
# GQA expansion
|
# GQA expansion
|
||||||
num_kv_heads = K_chunk.shape[1]
|
num_kv_heads = K_chunk.shape[1]
|
||||||
|
if num_heads != num_kv_heads:
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
|
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
|
||||||
|
|
||||||
|
# KV offset in reshaped space
|
||||||
|
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
|
||||||
|
|
||||||
|
# Compute raw attention scores
|
||||||
|
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
||||||
|
Q, K_chunk, self.stride,
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
chunk_end=chunk_end,
|
||||||
|
is_causal=False, # K 不完整,不能在这里用 causal
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute partial stats (带 causal mask)
|
||||||
|
m_partial, l_partial = softmax_compute_partial_stats(
|
||||||
|
attn_weights_kv,
|
||||||
|
reshaped_block_size,
|
||||||
|
segment_size,
|
||||||
|
scale,
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
kv_offset=kv_offset_reshaped,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
m_chunks.append(m_partial)
|
||||||
|
l_chunks.append(l_partial)
|
||||||
|
|
||||||
|
offload_engine.record_slot_compute_done(slot)
|
||||||
|
del attn_weights_kv
|
||||||
|
|
||||||
|
# Process current chunk K (already on GPU) on compute_stream
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
|
||||||
|
K_current = k.unsqueeze(0).transpose(1, 2)
|
||||||
|
|
||||||
|
# GQA expansion for current chunk
|
||||||
|
num_kv_heads = K_current.shape[1]
|
||||||
if num_heads != num_kv_heads:
|
if num_heads != num_kv_heads:
|
||||||
num_groups = num_heads // num_kv_heads
|
num_groups = num_heads // num_kv_heads
|
||||||
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
|
K_current = K_current.repeat_interleave(num_groups, dim=1)
|
||||||
|
|
||||||
# KV offset in reshaped space
|
# Pad current K to alignment
|
||||||
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
|
curr_k_len = K_current.shape[2]
|
||||||
|
padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
|
||||||
|
if padded_curr_k_len != curr_k_len:
|
||||||
|
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)
|
||||||
|
|
||||||
# Compute raw attention scores
|
# KV offset for current chunk
|
||||||
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
kv_offset_current = num_historical_blocks * kv_chunk_reshaped
|
||||||
Q, K_chunk, self.stride,
|
|
||||||
|
# Compute attention scores for current chunk
|
||||||
|
attn_weights_curr = flat_group_gemm_fuse_reshape(
|
||||||
|
Q, K_current, self.stride,
|
||||||
chunk_start=chunk_start,
|
chunk_start=chunk_start,
|
||||||
chunk_end=chunk_end,
|
chunk_end=chunk_end,
|
||||||
is_causal=False, # K 不完整,不能在这里用 causal
|
is_causal=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute partial stats (带 causal mask)
|
# Compute partial stats for current chunk
|
||||||
m_partial, l_partial = softmax_compute_partial_stats(
|
m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
|
||||||
attn_weights_kv,
|
attn_weights_curr,
|
||||||
reshaped_block_size,
|
reshaped_block_size,
|
||||||
segment_size,
|
segment_size,
|
||||||
scale,
|
scale,
|
||||||
chunk_start=chunk_start,
|
chunk_start=chunk_start,
|
||||||
kv_offset=kv_offset_reshaped,
|
kv_offset=kv_offset_current,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
)
|
)
|
||||||
m_chunks.append(m_partial)
|
m_chunks.append(m_partial_curr)
|
||||||
l_chunks.append(l_partial)
|
l_chunks.append(l_partial_curr)
|
||||||
|
|
||||||
offload_engine.record_slot_compute_done(slot)
|
del attn_weights_curr
|
||||||
del attn_weights_kv
|
|
||||||
|
|
||||||
# Process current chunk K (already on GPU)
|
|
||||||
# k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
|
|
||||||
K_current = k.unsqueeze(0).transpose(1, 2)
|
|
||||||
|
|
||||||
# GQA expansion for current chunk
|
|
||||||
num_kv_heads = K_current.shape[1]
|
|
||||||
if num_heads != num_kv_heads:
|
|
||||||
num_groups = num_heads // num_kv_heads
|
|
||||||
K_current = K_current.repeat_interleave(num_groups, dim=1)
|
|
||||||
|
|
||||||
# Pad current K to alignment
|
|
||||||
curr_k_len = K_current.shape[2]
|
|
||||||
padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
|
|
||||||
if padded_curr_k_len != curr_k_len:
|
|
||||||
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)
|
|
||||||
|
|
||||||
# KV offset for current chunk
|
|
||||||
kv_offset_current = num_historical_blocks * kv_chunk_reshaped
|
|
||||||
|
|
||||||
# Compute attention scores for current chunk
|
|
||||||
attn_weights_curr = flat_group_gemm_fuse_reshape(
|
|
||||||
Q, K_current, self.stride,
|
|
||||||
chunk_start=chunk_start,
|
|
||||||
chunk_end=chunk_end,
|
|
||||||
is_causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute partial stats for current chunk
|
|
||||||
m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
|
|
||||||
attn_weights_curr,
|
|
||||||
reshaped_block_size,
|
|
||||||
segment_size,
|
|
||||||
scale,
|
|
||||||
chunk_start=chunk_start,
|
|
||||||
kv_offset=kv_offset_current,
|
|
||||||
is_causal=True,
|
|
||||||
)
|
|
||||||
m_chunks.append(m_partial_curr)
|
|
||||||
l_chunks.append(l_partial_curr)
|
|
||||||
del attn_weights_curr
|
|
||||||
|
|
||||||
# ================================================================
|
# ================================================================
|
||||||
# Step 2: Merge all partial stats
|
# Step 2: Merge all partial stats (on compute_stream)
|
||||||
# ================================================================
|
# ================================================================
|
||||||
with nvtx.range("xattn_estimate_merge"):
|
with torch.cuda.stream(compute_stream):
|
||||||
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
|
with nvtx.range("xattn_estimate_merge"):
|
||||||
del m_chunks, l_chunks
|
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
|
||||||
|
|
||||||
|
del m_chunks, l_chunks
|
||||||
|
|
||||||
# ================================================================
|
# ================================================================
|
||||||
# Step 3: Second pass - normalize and compute block sums
|
# Step 3: Second pass - normalize and compute block sums
|
||||||
@@ -736,30 +721,61 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
|
|
||||||
# Process historical blocks again
|
# Process historical blocks again
|
||||||
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
|
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
|
||||||
|
# Load K from CPU (on slot_transfer_stream)
|
||||||
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||||
|
# wait_slot_layer makes compute_stream wait for H2D transfer
|
||||||
offload_engine.wait_slot_layer(slot)
|
offload_engine.wait_slot_layer(slot)
|
||||||
|
|
||||||
k_block = offload_engine.get_k_for_slot(slot)
|
# All compute kernels run on compute_stream
|
||||||
K_chunk = k_block.transpose(1, 2)
|
with torch.cuda.stream(compute_stream):
|
||||||
|
k_block = offload_engine.get_k_for_slot(slot)
|
||||||
|
K_chunk = k_block.transpose(1, 2)
|
||||||
|
|
||||||
num_kv_heads = K_chunk.shape[1]
|
num_kv_heads = K_chunk.shape[1]
|
||||||
if num_heads != num_kv_heads:
|
if num_heads != num_kv_heads:
|
||||||
num_groups = num_heads // num_kv_heads
|
num_groups = num_heads // num_kv_heads
|
||||||
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
|
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
|
||||||
|
|
||||||
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
|
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
|
||||||
|
|
||||||
# Recompute attention scores (trade-off: compute vs memory)
|
# Recompute attention scores (trade-off: compute vs memory)
|
||||||
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
||||||
Q, K_chunk, self.stride,
|
Q, K_chunk, self.stride,
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
chunk_end=chunk_end,
|
||||||
|
is_causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize with global stats and compute block sums
|
||||||
|
block_sum_kv = softmax_normalize_and_block_sum(
|
||||||
|
attn_weights_kv,
|
||||||
|
m_global,
|
||||||
|
l_global,
|
||||||
|
reshaped_block_size,
|
||||||
|
segment_size,
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||||
|
scale=scale,
|
||||||
|
kv_offset=kv_offset_reshaped,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
attn_sum_per_kv.append(block_sum_kv)
|
||||||
|
|
||||||
|
offload_engine.record_slot_compute_done(slot)
|
||||||
|
del attn_weights_kv
|
||||||
|
|
||||||
|
# Process current chunk on compute_stream
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# Recompute attention scores for current chunk
|
||||||
|
attn_weights_curr = flat_group_gemm_fuse_reshape(
|
||||||
|
Q, K_current, self.stride,
|
||||||
chunk_start=chunk_start,
|
chunk_start=chunk_start,
|
||||||
chunk_end=chunk_end,
|
chunk_end=chunk_end,
|
||||||
is_causal=False,
|
is_causal=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalize with global stats and compute block sums
|
block_sum_curr = softmax_normalize_and_block_sum(
|
||||||
block_sum_kv = softmax_normalize_and_block_sum(
|
attn_weights_curr,
|
||||||
attn_weights_kv,
|
|
||||||
m_global,
|
m_global,
|
||||||
l_global,
|
l_global,
|
||||||
reshaped_block_size,
|
reshaped_block_size,
|
||||||
@@ -767,67 +783,42 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
chunk_start=chunk_start,
|
chunk_start=chunk_start,
|
||||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
kv_offset=kv_offset_reshaped,
|
kv_offset=kv_offset_current,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
)
|
)
|
||||||
attn_sum_per_kv.append(block_sum_kv)
|
attn_sum_per_kv.append(block_sum_curr)
|
||||||
|
del attn_weights_curr, K_current
|
||||||
offload_engine.record_slot_compute_done(slot)
|
|
||||||
del attn_weights_kv
|
|
||||||
|
|
||||||
# Process current chunk
|
|
||||||
# Recompute attention scores for current chunk
|
|
||||||
attn_weights_curr = flat_group_gemm_fuse_reshape(
|
|
||||||
Q, K_current, self.stride,
|
|
||||||
chunk_start=chunk_start,
|
|
||||||
chunk_end=chunk_end,
|
|
||||||
is_causal=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
block_sum_curr = softmax_normalize_and_block_sum(
|
|
||||||
attn_weights_curr,
|
|
||||||
m_global,
|
|
||||||
l_global,
|
|
||||||
reshaped_block_size,
|
|
||||||
segment_size,
|
|
||||||
chunk_start=chunk_start,
|
|
||||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
|
||||||
scale=scale,
|
|
||||||
kv_offset=kv_offset_current,
|
|
||||||
is_causal=True,
|
|
||||||
)
|
|
||||||
attn_sum_per_kv.append(block_sum_curr)
|
|
||||||
del attn_weights_curr, K_current
|
|
||||||
|
|
||||||
# ================================================================
|
# ================================================================
|
||||||
# Step 4: Concatenate block sums and select blocks
|
# Step 4: Concatenate block sums and select blocks (on compute_stream)
|
||||||
# ================================================================
|
# ================================================================
|
||||||
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
|
with torch.cuda.stream(compute_stream):
|
||||||
del attn_sum_per_kv, m_global, l_global
|
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
|
||||||
|
del attn_sum_per_kv, m_global, l_global
|
||||||
|
|
||||||
# Calculate q_block offset for find_blocks_chunked
|
# Calculate q_block offset for find_blocks_chunked
|
||||||
# This is the number of BSA blocks before Q in the full sequence
|
# This is the number of BSA blocks before Q in the full sequence
|
||||||
num_blocks_per_chunk = q_reshaped_len // reshaped_block_size
|
num_blocks_per_chunk = q_reshaped_len // reshaped_block_size
|
||||||
current_index = k_block_num - q_block_num # Q starts at this BSA block index
|
current_index = k_block_num - q_block_num # Q starts at this BSA block index
|
||||||
|
|
||||||
with nvtx.range("xattn_find_blocks"):
|
with nvtx.range("xattn_find_blocks"):
|
||||||
mask = find_blocks_chunked(
|
mask = find_blocks_chunked(
|
||||||
attn_sum_concat,
|
attn_sum_concat,
|
||||||
current_index=current_index,
|
current_index=current_index,
|
||||||
threshold=self.threshold,
|
threshold=self.threshold,
|
||||||
num_to_choose=None,
|
num_to_choose=None,
|
||||||
decoding=False,
|
decoding=False,
|
||||||
mode="prefill",
|
mode="prefill",
|
||||||
causal=True,
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply causal mask post-processing (same as xattn.py lines 1300-1306)
|
||||||
|
mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||||
|
torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
|
||||||
|
mask[:, :, -q_block_num:, -q_block_num:],
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply causal mask post-processing (same as xattn.py lines 1300-1306)
|
|
||||||
mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
|
||||||
torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
|
|
||||||
mask[:, :, -q_block_num:, -q_block_num:],
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ================================================================
|
# ================================================================
|
||||||
# Step 5: Record density (only on layer 0)
|
# Step 5: Record density (only on layer 0)
|
||||||
# ================================================================
|
# ================================================================
|
||||||
@@ -908,14 +899,6 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
if available_blocks and available_blocks[-1] not in selected_block_ids:
|
if available_blocks and available_blocks[-1] not in selected_block_ids:
|
||||||
selected_block_ids.append(available_blocks[-1])
|
selected_block_ids.append(available_blocks[-1])
|
||||||
|
|
||||||
# Record communication density
|
|
||||||
if available_blocks:
|
|
||||||
DensityObserver.record_comm_density(
|
|
||||||
layer_id,
|
|
||||||
selected_cpu_blocks=len(selected_block_ids),
|
|
||||||
total_cpu_blocks=len(available_blocks),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update statistics (only for layer 0 to avoid overcounting)
|
# Update statistics (only for layer 0 to avoid overcounting)
|
||||||
if layer_id == 0 and available_blocks:
|
if layer_id == 0 and available_blocks:
|
||||||
self._stats_total_available_blocks += len(available_blocks)
|
self._stats_total_available_blocks += len(available_blocks)
|
||||||
|
|||||||
Reference in New Issue
Block a user