AMD
Triton-based optimization of the different ROCm backend implementations: the vllm backend now basically runs inference normally, plus the DocLayout-YOLO model used in the first layout step of the pipeline backend.
If you already have a complete Python vllm + MinerU environment, jump straight to step 5!!! If you hit similar problems on other GPUs, the same approach applies: first profile to locate the offending operator, then reimplement it as a Triton kernel. In my quick tests the output is basically the same as the official MinerU results. Since not many people run this on AMD, I'm just sharing it here in the comments.
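For the "profile first" step on other GPUs, a minimal torch.profiler sketch like the one below is usually enough to see which operator dominates. The Conv3d here is only a hypothetical stand-in for whatever forward pass you are actually debugging (on ROCm builds of PyTorch the GPU is still addressed through the cuda device API):

import torch
from torch.profiler import profile, ProfilerActivity

# Hypothetical stand-in for the model under investigation; swap in the real forward pass.
model = torch.nn.Conv3d(3, 1152, kernel_size=(2, 14, 14), stride=(2, 14, 14),
                        bias=False).to("cuda", torch.bfloat16)
x = torch.randn(512, 3, 2, 14, 14, device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    model(x)
    torch.cuda.synchronize()

# The slow operator shows up at the top when sorted by device time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))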
1. Results
As a supplement, I ran a 200-page Python programming book PDF to test speed; it reaches 1.99 it/s:
Two Step Extraction: 100%|████████████████████████████████████████| 200/200 [01:40<00:00, 1.99it/s]
Below are the earlier results on 14 academic papers, on a 7900xtx with:
mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true
Speed is roughly 1.6-1.8 s/it (not carefully benchmarked; I only tried two documents). The second approach, which replaces the original dot product with a matrix multiply, pushes this further to about 1.3 s/it. After that optimization, the main operator time is split between hipBLAS (nothing more to gain there) and vllm's Triton backend, roughly 25% each; the vllm Triton backend part can only wait for upstream optimization. DocLayout-YOLO's layout speed goes from 1.6 it/s up to 15 it/s; note that the input PDF sizes have to be cached, since Triton requires cached sizes. This is done mainly to preserve the model's input/output interface with minimal code changes. Here is an example using the -b vlm-vllm-engine mode.
Test result with the 5D matrix multiply replacing the original dot product:
2025-10-05 15:45:12.985 | INFO | mineru.backend.vlm.vlm_analyze:get_model:128 - get vllm-engine predictor cost: 18.45s
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 12.20it/s]
Processed prompts: 100%|█████████████████████| 14/14 [00:08<00:00, 1.56it/s, est. speed input: 2174.18 toks/s, output: 791.87 toks/s]
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████| 278/278 [00:00<00:00, 323.03it/s]
Processed prompts: 100%|██████████████████| 278/278 [00:07<00:00, 37.63it/s, est. speed input: 5264.66 toks/s, output: 2733.31 toks/s]
Test with mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true:
2025-10-05 15:46:55.953 | WARNING | mineru.cli.common:convert_pdf_bytes_to_bytes_by_pypdfium2:54 - end_page_id is out of range, use pdf_docs length
Two Step Extraction: 100%|████████████████████████████████████████████████████████████████████████████| 14/14 [00:18<00:00, 1.30s/it]
2. Root cause
The vllm backend has a severe performance problem on AMD RDNA. The cause is that one operator in vllm's qwen2_vl.py has no corresponding ROCm kernel implementation, so the convolution falls back to an extremely slow path; a single execution took 12 s... hard to put into words. Concretely, the MIOpen library lacks an optimized kernel for the specific Conv3d (bfloat16) used by this model. The dilated convolution in DocLayout-YOLO's g2l_crm.py suffers from the same issue, and even the professional CDNA MI210 is not spared, so both are handled together here.
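A quick way to confirm the fallback on your own card is to time that exact Conv3d configuration (bfloat16, kernel and stride of (2, 14, 14), 1152 output channels, as in the patch-embedding layer) against the mathematically equivalent flattened matrix multiply. This is only a rough timing sketch; the batch size is arbitrary:

import time
import torch

def bench(fn, iters=10):
    fn()                           # warm-up
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

# Same shapes as the Qwen2-VL patch embedding: Conv3d(3, 1152, k=(2,14,14), stride=k)
conv = torch.nn.Conv3d(3, 1152, kernel_size=(2, 14, 14), stride=(2, 14, 14),
                       bias=False).to("cuda", torch.bfloat16)
x = torch.randn(4096, 3, 2, 14, 14, device="cuda", dtype=torch.bfloat16)
w2d = conv.weight.view(1152, -1)

print("Conv3d (MIOpen path):", bench(lambda: conv(x)), "s/iter")
print("Equivalent GEMM     :", bench(lambda: x.view(x.shape[0], -1) @ w2d.t()), "s/iter")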
3. Environment
System: Ubuntu 24.04.3
Kernel: Linux 6.14.0-33-generic
ROCm version: 7.0.1
Python environment:
python 3.12
pytorch-triton-rocm 3.5.0+gitbbb06c03
torch 2.10.0.dev20251001+rocm7.0
torchvision 0.25.0.dev20251003+rocm7.0
vllm 0.11.0rc2.dev198+g736fbf4c8.rocm701
The exact versions don't matter; the fix is the same either way.
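To quickly confirm which builds are actually active inside the venv (torch.version.hip is only set on ROCm builds of torch):

import torch

print("torch    :", torch.__version__)
print("HIP/ROCm :", torch.version.hip)        # None on CUDA builds
print("GPU seen :", torch.cuda.is_available())
print("device   :", torch.cuda.get_device_name(0))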
4. Prerequisite environment setup
uv venv --python python3.12
source .venv/bin/activate
uv pip install --pre torch torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple/ --extra-index-url https://download.pytorch.org/whl/nightly/rocm7.0
uv pip install pip
# Switch to pip here (instead of continuing with uv pip) to avoid overwriting our locally installed PyTorch
pip install -U "mineru[core]" -i https://pypi.mirrors.ustc.edu.cn/simple/
# Manually install aiter, vllm, amd-smi, etc.: pick a location of your choice, clone, then cd into that directory
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
git submodule sync; git submodule update --init --recursive
python setup.py develop
cd ..
git clone https://github.com/vllm-project/vllm.git
cd vllm/
cp -r /opt/rocm/share/amd_smi ~/Pytorch/vllm/
pip install amd_smi/
pip install --upgrade numba \
scipy \
huggingface-hub[cli,hf_transfer] \
setuptools_scm
pip install -r requirements/rocm.txt
export PYTORCH_ROCM_ARCH="gfx1100" # set to your GPU architecture: rocminfo | grep gfx
python setup.py develop
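Once the build finishes, a minimal sanity check that the from-source vllm imports and that torch reports the gfx target you exported in PYTORCH_ROCM_ARCH. On ROCm builds the device properties expose gcnArchName; if your build lacks that field, rocminfo | grep gfx from the shell gives the same answer:

import torch
import vllm

print("vllm:", vllm.__version__)
# Should match the architecture you built for, e.g. gfx1100 on a 7900xtx
print("arch:", torch.cuda.get_device_properties(0).gcnArchName)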
5. Adding the key Triton operator to vllm
Here I give two solutions. Solution 1 is the one mentioned above, which gets you to roughly 1.5-1.8 s/it. Solution 2 goes further and manually rewrites the operator as a matrix multiply; it definitely works on the 7900xtx at about 1.3 s/it, and other AMD GPUs should also see a speedup over Solution 1, but it is not guaranteed to be the fastest possible configuration there: the hand-tuned parts may need adjusting.
Note: uninstall the Triton-backend flash_attn with pip. I spent a long time trying all sorts of fixes and it kept throwing errors; the problems run deep, so simply don't use it.
# Locate your vllm installation (the XXX path below)
pip show vllm
Then open qwen2_vl.py in that vllm installation and make sure the imports near the top include the following; the last line is the new import, pointing at a kernel file we will create next to qwen2_vl.py:
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .qwen2_vl_vision_kernels import triton_conv3d_patchify
Solution 1: in qwen2_vl.py, around line 498, replace the class Qwen2VisionPatchEmbed(nn.Module) with the version below. (PS: this is exactly the operator AMD has no ready-made kernel for, which causes the fallback.)
class Qwen2VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.embed_dim = embed_dim
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d(in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
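        # x arrives flattened as (L, C) with C = in_channels * temporal_patch_size * patch_size**2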
x_reshaped = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
# Call your custom Triton kernel instead of self.proj
x_out = triton_conv3d_patchify(x_reshaped, self.proj.weight)
# The output of our kernel is already the correct shape [L, embed_dim]
return x_out
Next, create qwen2_vl_vision_kernels.py in the same directory as qwen2_vl.py (so the relative import above resolves), with the following content:
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _conv3d_patchify_kernel(
# Pointers to tensors
X, W, Y,
# Tensor dimensions
N, C_in, D_in, H_in, W_in,
C_out, KD, KH, KW,
# Stride and padding for memory access
stride_xn, stride_xc, stride_xd, stride_xh, stride_xw,
stride_wn, stride_wc, stride_wd, stride_wh, stride_ww,
stride_yn, stride_yc,
# Triton-specific metaparameters
BLOCK_SIZE: tl.constexpr,
):
"""
Triton kernel for a non-overlapping 3D patching convolution.
Each kernel instance computes one output value for one patch.
"""
# Get the program IDs for the N (patch) and C_out (output channel) dimensions
pid_n = tl.program_id(0) # The index of the patch we are processing
pid_cout = tl.program_id(1) # The index of the output channel we are computing
# --- Calculate memory pointers ---
# Pointer to the start of the current input patch
x_ptr = X + (pid_n * stride_xn)
# Pointer to the start of the current filter (weight)
w_ptr = W + (pid_cout * stride_wn)
# Pointer to where the output will be stored
y_ptr = Y + (pid_n * stride_yn + pid_cout * stride_yc)
# --- Perform the convolution (element-wise product and sum) ---
# This is a dot product between the flattened patch and the flattened filter.
accumulator = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
# Iterate over the elements of the patch/filter
for c_offset in range(0, C_in):
for d_offset in range(0, KD):
for h_offset in range(0, KH):
# Unrolled loop for the innermost dimension (width) for performance
for w_offset in range(0, KW, BLOCK_SIZE):
# Create masks to handle cases where KW is not a multiple of BLOCK_SIZE
w_range = w_offset + tl.arange(0, BLOCK_SIZE)
w_mask = w_range < KW
# Calculate offsets to load data
patch_offset = (c_offset * stride_xc + d_offset * stride_xd +
h_offset * stride_xh + w_range * stride_xw)
filter_offset = (c_offset * stride_wc + d_offset * stride_wd +
h_offset * stride_wh + w_range * stride_ww)
# Load patch and filter data, applying masks
patch_vals = tl.load(x_ptr + patch_offset, mask=w_mask, other=0.0)
filter_vals = tl.load(w_ptr + filter_offset, mask=w_mask, other=0.0)
# Multiply and accumulate
accumulator += patch_vals.to(tl.float32) * filter_vals.to(tl.float32)
# Sum the accumulator block and store the single output value
output_val = tl.sum(accumulator, axis=0)
tl.store(y_ptr, output_val)
def triton_conv3d_patchify(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Python wrapper for the 3D patching convolution Triton kernel.
"""
# Get tensor dimensions
N, C_in, D_in, H_in, W_in = x.shape
C_out, _, KD, KH, KW = weight.shape
# Create the output tensor
# The output of this specific conv is (N, C_out, 1, 1, 1), which we squeeze
Y = torch.empty((N, C_out), dtype=x.dtype, device=x.device)
# Define the grid for launching the Triton kernel
# Each kernel instance handles one patch (N) for one output channel (C_out)
grid = (N, C_out)
# Launch the kernel
# We pass all strides to make the kernel flexible
_conv3d_patchify_kernel[grid](
x, weight, Y,
N, C_in, D_in, H_in, W_in,
C_out, KD, KH, KW,
x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
weight.stride(0), weight.stride(1), weight.stride(2), weight.stride(3), weight.stride(4),
Y.stride(0), Y.stride(1),
BLOCK_SIZE=16, # A reasonable default, can be tuned
)
return Y
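Before wiring the kernel into vllm, it is worth sanity-checking it against F.conv3d on random data with the same patch-embedding shapes. A standalone sketch: it assumes triton_conv3d_patchify from the file above is importable or pasted into the same script, and uses a float32 reference to stay off the slow bf16 conv path:

import torch
import torch.nn.functional as F

L, C_in, T, P, embed_dim = 256, 3, 2, 14, 1152   # patches, channels, temporal/spatial patch, out dim

x = torch.randn(L, C_in, T, P, P, device="cuda", dtype=torch.bfloat16)
w = torch.randn(embed_dim, C_in, T, P, P, device="cuda", dtype=torch.bfloat16)

# Reference: the original Conv3d formulation, computed in float32
ref = F.conv3d(x.float(), w.float(), stride=(T, P, P)).view(L, embed_dim)

out = triton_conv3d_patchify(x, w)                # Triton path (Solution 1)
print("max abs diff:", (out.float() - ref).abs().max().item())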
Solution 2: same place, qwen2_vl.py around line 498, class Qwen2VisionPatchEmbed(nn.Module) (again, the operator AMD has no ready-made kernel for). Here we flatten the 5D tensor in one step and replace the convolution with a matrix multiply:
class Qwen2VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.embed_dim = embed_dim
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d(in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
x_reshaped_5d = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
return triton_conv3d_patchify(x_reshaped_5d, self.proj.weight)
For Solution 2, qwen2_vl_vision_kernels.py instead contains the GEMM kernel below:
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _conv_gemm_kernel(
A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
offs_k += BLOCK_K
c = accumulator.to(C.dtype.element_ty)
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def triton_conv3d_patchify(x_5d: torch.Tensor, weight_5d: torch.Tensor) -> torch.Tensor:
N_patches, _, _, _, _ = x_5d.shape
C_out, _, _, _, _ = weight_5d.shape
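    # Flatten patches to an (M, K) matrix and filters to a (K, N) matrix so the
    # non-overlapping patching conv becomes a single GEMM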
A = x_5d.view(N_patches, -1)
B = weight_5d.view(C_out, -1).transpose(0, 1).contiguous()
M, K = A.shape
_K, N = B.shape
assert K == _K
C = torch.empty((M, N), device=A.device, dtype=A.dtype)
    # --- Hand-tuned config for the 7900xtx; other GPUs may need their own best combination (AMD autotune is effectively a no-op) ---
best_config = {
'BLOCK_M': 128,
'BLOCK_N': 128,
'BLOCK_K': 32,
}
num_stages = 4
num_warps = 8
grid = (triton.cdiv(M, best_config['BLOCK_M']),
triton.cdiv(N, best_config['BLOCK_N']))
_conv_gemm_kernel[grid](
A, B, C,
M, N, K,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
**best_config,
num_stages=num_stages,
num_warps=num_warps
)
return C
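The same kind of check works for the GEMM version: because the patching convolution's stride equals its kernel size, its output must equal a plain matrix multiply of the flattened patches against the flattened filters. A standalone sketch, again assuming this Solution 2 variant of triton_conv3d_patchify is importable:

import torch

L, C_in, T, P, embed_dim = 256, 3, 2, 14, 1152

x = torch.randn(L, C_in, T, P, P, device="cuda", dtype=torch.bfloat16)
w = torch.randn(embed_dim, C_in, T, P, P, device="cuda", dtype=torch.bfloat16)

# Reference: flattened-GEMM formulation of the patching conv, in float32
ref = x.view(L, -1).float() @ w.view(embed_dim, -1).float().t()

out = triton_conv3d_patchify(x, w)                # Triton GEMM path (Solution 2)
print("max abs diff:", (out.float() - ref).abs().max().item())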
One more thing: after closing the terminal, launching mineru-gradio again reports a LoRA error; patch the code to skip it.
pip show mineru_vl_utils
Open XXX/mineru_vl_utils/vlm_client/vllm_async_engine_client.py and change line 58, self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer(), to:
try:
self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer()
except AttributeError:
    # If get_lora_tokenizer is not available, fall back to the original tokenizer
self.tokenizer = vllm_async_llm.tokenizer
Finally, set these two environment variables and enjoy:
export MINERU_MODEL_SOURCE=modelscope
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
6. The vllm backend is now fine; what remains is the dilated-convolution problem in the DocLayout-YOLO model used for layout in the pipeline backend
I posted an answer under DocLayout-YOLO covering this, so I won't repeat the pipeline's dilated-convolution fix here; just follow that link.
Find your doclayout-yolo installation location as shown below, then go in and modify the file described in the reply at that link:
pip show doclayout-yolo