b9000
hexagon: hmx flash attention (#22347)
hmx: extract shared interleave headers and unify batched matmul
hmx: add HMX-accelerated flash attention for prefill
hmx: replace asm wrappers with Q6_ intrinsics in hmx-utils.h
Switches three single-instruction helpers from inline asm to the matching Q6_ intrinsics, following the style established by aizip f8737609a and used by the hmx-matmul-ops.c rewrite in upstream PR #21554:
- hmx_set_output_scales: asm "bias=mxmem2" -> Q6_bias_mxmem2_A
- hmx_load_tile_pair_fp16: asm packet -> Q6_activation_hf_mxmem_RR + Q6_weight_hf_mxmem_RR
- hmx_consume_accumulator_fp16: asm "mxmem=acc" -> Q6_mxmem_AR_after_hf
hmx_load_tiles_fp16 stays on inline asm: it uses ":deep" activation streaming, and the mixed Q6_activation_hf_mxmem_RR_deep + non-deep Q6_weight_hf_mxmem_RR pair fails the HMX backend constraint check ("activate weight pair (1) exceeds limit (1)"). The asm bundle keeps both halves in one VLIW packet and avoids the diagnostic.
Functionally equivalent — same instructions emitted; the Q6_ intrinsics just give the compiler more visibility for scheduling.
hmx: drop the duplicate interleave_fp16_weight_chunk_to_tiles
hmx: apply upstream optimizations to hmx-flash-attn-ops.c
Apply restrict, __builtin_assume, and pointer accumulation to the three HMX workers (qk_dot, o_update, o_norm) and to the matching inline HMX loops in op_hmx_flash_attn_ext.
hmx: unify interleave helper
hmx: multi-thread Q load / O store and enable prefill FA dispatch
Extract inline Q-load and O-store loops into worker_pool-parallel helpers (fa_phase_q_load, fa_phase_o_store) so HVX threads split the F32↔F16 conversion work across row ranges. Also relax the softmax threading gate from n_row_vec_cnt >= n_threads to >= 2, which was unnecessarily forcing single-thread fallback when n_rows_g < 512.
On the dispatch side, remove the ne[2] != 1 guard that blocked multi-head (prefill) FA from reaching the HTP backend — GQA is already handled internally by both the HMX and HVX flash-attention paths.
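The relaxed threading gate and the row-range split described above can be sketched roughly as follows. This is a scalar illustration under stated assumptions: `row_range_t`, `fa_pick_n_workers` and `fa_row_range` are hypothetical names, not functions from the Hexagon backend, and the real code splits HVX row vectors rather than a generic item count.

```c
#include <assert.h>

/* Hypothetical sketch: names are illustrative, not from the Hexagon backend. */
typedef struct {
    int start; /* first work item owned by this worker (inclusive) */
    int end;   /* one past the last work item owned by this worker */
} row_range_t;

/* Number of workers for n_vec work items on n_threads HVX threads.
 * Relaxed gate: go parallel as soon as there are at least 2 items,
 * instead of requiring n_vec >= n_threads (which forced 1 thread for
 * small inputs even when several threads could share the work). */
static int fa_pick_n_workers(int n_vec, int n_threads) {
    if (n_vec < 2 || n_threads < 2) {
        return 1; /* nothing worth splitting: single-thread fallback */
    }
    return n_vec < n_threads ? n_vec : n_threads;
}

/* Contiguous, non-overlapping range of items for worker ith of n_workers. */
static row_range_t fa_row_range(int n_items, int n_workers, int ith) {
    assert(n_workers > 0 && ith >= 0 && ith < n_workers);
    const int per_worker = (n_items + n_workers - 1) / n_workers; /* ceil division */
    int start = ith * per_worker;
    int end   = start + per_worker;
    if (start > n_items) start = n_items; /* trailing workers may get an empty range */
    if (end > n_items)   end   = n_items;
    row_range_t r = { start, end };
    return r;
}
```

With 3 items and 4 threads the old gate would have fallen back to one thread, while this version uses 3 workers, each converting its own contiguous slice.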
hmx: relax matmul pipeline gate to cover k > n shapes (e.g. FFN_down)
hmx: optimize FA softmax mask phase (no-ALiBi fast path + GQA dedup)
hmx: Add an asm memory clobber at the phase boundary to prevent a reordering bug
[experimental]: fp16 softmax (EXP2_HF) to accelerate fa
Bake log2(e) into qk_scale and use hvx_exp2_hf directly for P and m_diff (base-2 consistent, matches htp-ops-lib). ~22 ALU ops for 64 lanes vs ~44 for the F32 round-trip path.
hmx flash-attn: refine cost model coefficients based on profiling data
hmx flash-attn: replace asm clobber with targeted volatile reads on vtcm_d_tiles
hmx flash-attn: fix prefill correctness (dst indexing, softmax reduce, V stride)
hmx flash-attn: fix p_tiles dual-tile OOB race; enable MT + pipeline
hmx flash-attn: preserve additive mask bias in no-ALiBi fast path
The no-ALiBi fast path (max_bias==0) was skipping mask add entirely on the assumption that mask values are only {0, -inf}. This is wrong when the mask carries additive positional bias — those terms were silently dropped. Keep the slope-mul skip (slope≡1.0) but add mask back so the bias survives; vmux still clamps below -16 to -inf.
Also add HMX FA coverage to test-backend-ops: prefill shapes (nb=64, nb=32) × {mask on/off} × {ALiBi on/off} × {softcap on/off}, F16 KV, hs ∈ {64, 128}.
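A scalar sketch of the corrected mask phase follows. It assumes the -16 threshold applies to the mask value itself (the commit message only says vmux clamps below -16 to -inf), and `fa_apply_mask_row` and `MASK_NEG_CLAMP` are hypothetical names; the real code operates on HVX vectors.

```c
#include <assert.h>
#include <math.h>

/* Assumption: mask values below this threshold are treated as -inf. */
#define MASK_NEG_CLAMP (-16.0f)

/* Add the attention mask to one row of scaled scores s[0..n). */
static void fa_apply_mask_row(float *s, const float *mask, int n, float max_bias, float slope) {
    assert(n >= 0);
    if (max_bias == 0.0f) {
        /* No-ALiBi fast path: slope is 1.0, so skip the multiply, but still
         * add the mask because it may carry additive positional bias,
         * not only {0, -inf}. */
        for (int i = 0; i < n; ++i) {
            s[i] = mask[i] < MASK_NEG_CLAMP ? -INFINITY : s[i] + mask[i];
        }
    } else {
        /* ALiBi path: scale the mask by the per-head slope before adding. */
        for (int i = 0; i < n; ++i) {
            s[i] += slope * mask[i];
        }
    }
}
```

The bug being fixed corresponds to a fast path that only applied the -inf clamp and never added `mask[i]`, which silently dropped any finite bias such as the 0.5 in the first test row.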
- hmx: fix softcap+EXP2_HF interaction, tighten matmul pipeline gate, add FA tests
- flash-attn: when EXP2_HF is on AND logit_softcap is active, fold log2(e) into the post-tanh multiplier (v_cap) instead of pre-baking it into qk_scale. Pre-baking shifted the tanh knee from x≈c to x≈c/log2(e) and produced numerically wrong softcapped outputs whenever both knobs were enabled.
- flash-attn softmax (fa_softmax_thread): replace the union+memcpy scalar extract pattern with HVX vmux-based per-row accumulators on rowmax/rowsum. Add hvx_vec_get_f16 helper in hvx-base.h. Functional parity, less scalar code, clearer hf/qf16 lane-format contract.
- matmul (hmx_mat_mul_permuted_qk_0_d16a32): pick pipeline vs sequential layout based on whether the chunker actually yields >=2 n-chunks, instead of the static (m>=128 && n>=256) gate. Avoids paying for output double-buffer + worker dispatch when there is no HMX/HVX overlap to gain (e.g. shapes that collapse to one n-chunk).
- tests: add HMX flash-attention coverage over the {mask, ALiBi (max_bias), logit_softcap} cross-product for the prefill path — head_dim 64/128, GQA 4×4, kv=512/nb=64 plus a kv=113/nb=32 non-aligned case.
[Help Wanted]: refactor D matrix computation into separate function for clarity and maintainability
format code
hexagon: -O3 appears to cause issues with the large code base; switch to -O2 and -flto instead
hexagon: use hex_ prefix for swap_ptr
hexagon: move vtcm_seq_alloc into vtcm-utils.h
More vtcm allocator updates are coming so it makes sense to start the separate hdr for it.
hmx-utils: add hmx_ prefix for layout converters
hmx-mm: move main hmx_mm functions to the end, remove unused fwd decls, etc
hmx-mm: remove unused qweight_fetch_task_state_t and minor alignment fixes
hmx-fa: minor alignment fixes
hmx-fa: move hmx_flash_atten into hmx-ops.h
hmx-fa: remove redundant workpool pointer in the hmx_fa_ctx, plus minor alignment updates
hmx-fa: minor alignment and simplifications
hexagon: move FA_EXP_F16 option to hostside CMake file
hmx-fa: use hvx_vec_splat_f16 instead of fp16_to_bits
hmx-fa: add hvx_splat_u16/u8 and use them in FA instead of the custom hvx_fill
hmx-fa: some more alignment updates in the core fa function
hmx-fa: keep slopes in vtcm in fp16
Saves malloc/free and removes the need for float -> fp16 downcast on every use.
hexagon: consistent noinline usage (after static)
hex-hmx: consistently use FARF_HIGH to enable debug output
hmx-utils: no need for always_inline attr
hex-hmx: consistent noinline usage (static noinline ...)
hex-hmx: simplify init_col_scales
hexagon: fix editorconfig errors
hmx-mm: minor alignment fixes
Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
macOS/iOS:
- macOS Apple Silicon (arm64)
- macOS Apple Silicon (arm64, KleidiAI enabled)
- macOS Intel (x64)
- iOS XCFramework
Linux:
- Ubuntu x64 (CPU)
- Ubuntu arm64 (CPU)
- Ubuntu s390x (CPU)
- Ubuntu x64 (Vulkan)
- Ubuntu arm64 (Vulkan)
- Ubuntu x64 (ROCm 7.2)
- Ubuntu x64 (OpenVINO)
- Ubuntu x64 (SYCL FP32)
- Ubuntu x64 (SYCL FP16)
Android:
Windows:
- Windows x64 (CPU)
- Windows arm64 (CPU)
- Windows x64 (CUDA 12) - CUDA 12.4 DLLs
- Windows x64 (CUDA 13) - CUDA 13.1 DLLs
- Windows x64 (Vulkan)
- Windows x64 (SYCL)
- Windows x64 (HIP)
openEuler: