
Triton-based optimizations for the different ROCm backend implementations: the vllm backend now runs inference normally, and DocLayout-YOLO, used in the first (layout) step of the pipeline backend, is covered as well.

If you already have a complete Python vllm and MinerU environment, jump straight to step 5!!! The same approach applies to problems on other GPUs: profile first to locate the problematic operator (a sketch follows), then implement it with a Triton backend. In my testing the output is essentially the same as the official MinerU results. Not many people run AMD, so I'm just sharing this here in the comments.
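
A minimal profiling sketch along those lines, using torch.profiler; the Conv3d shape below only mirrors the Qwen2-VL patch-embedding defaults used later in this post and is purely illustrative:

import torch
from torch.profiler import profile, ProfilerActivity

# Illustrative stand-in for the hot module (same defaults as Qwen2VisionPatchEmbed below)
conv = torch.nn.Conv3d(3, 1152, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
conv = conv.cuda().to(torch.bfloat16)
x = torch.randn(4096, 3, 2, 14, 14, device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    conv(x)
    torch.cuda.synchronize()

# The offending operator shows up at the top of this table (aten::conv3d / MIOpen kernels here)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))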

1. Results

As a supplementary test, a 200-page Python programming PDF reaches 1.99 it/s:

Two Step Extraction: 100%|████████████████████████████████████████| 200/200 [01:40<00:00, 1.99it/s]

Below are the earlier results on 14 academic papers, run on a 7900 XTX with mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true. Speed is roughly 1.6-1.8 s/it; I did not benchmark carefully, just tried two documents. The second approach, which replaces the original dot products with a matrix multiply, pushes this further to about 1.3 s/it. After optimization the remaining operator time is split roughly 25% each between hipBLASLt (no further gains to be had there) and the vllm Triton backend, which will have to wait on upstream optimization. DocLayout-YOLO's layout speed rises from 1.6 it/s to 15 it/s; note that the input PDF sizes need to be cached, since Triton has to cache the shapes. The change keeps the model's input/output interface untouched so the code modification stays minimal. The example below uses the -b vlm-vllm-engine mode.


Test result with the 5D matrix multiply replacing the original dot products:

2025-10-05 15:45:12.985 | INFO | mineru.backend.vlm.vlm_analyze:get_model:128 - get vllm-engine predictor cost: 18.45s
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 12.20it/s]
Processed prompts: 100%|█████████████████████| 14/14 [00:08<00:00, 1.56it/s, est. speed input: 2174.18 toks/s, output: 791.87 toks/s]
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████| 278/278 [00:00<00:00, 323.03it/s]
Processed prompts: 100%|██████████████████| 278/278 [00:07<00:00, 37.63it/s, est. speed input: 5264.66 toks/s, output: 2733.31 toks/s]

Test via mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true:

2025-10-05 15:46:55.953 | WARNING | mineru.cli.common:convert_pdf_bytes_to_bytes_by_pypdfium2:54 - end_page_id is out of range, use pdf_docs length
Two Step Extraction: 100%|████████████████████████████████████████████████████████████████████████████| 14/14 [00:18<00:00, 1.30s/it]


2. Root cause

AMD RDNA has a serious performance problem with the vllm backend. The cause is that one operator in vllm's qwen2_vl.py has no corresponding ROCm kernel implementation, so the convolution falls back to a terribly slow path; a single execution took 12 s, which is hard to put into words. Concretely, the MIOpen library lacks an optimized kernel for the specific Conv3d(bfloat16) used by the model. The dilated convolution in DocLayout-YOLO's g2l_crm.py suffers from the same issue, and even the professional CDNA MI210 does not escape it, so both are handled here in one go.
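
A rough way to reproduce the problem in isolation, assuming a ROCm PyTorch build; the shapes mirror the patch embedding, and the second half shows the mathematically equivalent GEMM that solution 2 below exploits:

import time
import torch

# The problematic pattern: bfloat16 Conv3d whose stride equals its kernel size (a pure patchify)
conv = torch.nn.Conv3d(3, 1152, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
conv = conv.cuda().to(torch.bfloat16)
x = torch.randn(2048, 3, 2, 14, 14, device="cuda", dtype=torch.bfloat16)

torch.cuda.synchronize(); t0 = time.time()
y_conv = conv(x)                                   # hits the missing-MIOpen-kernel fallback on RDNA
torch.cuda.synchronize(); print(f"Conv3d: {time.time() - t0:.3f}s")

t0 = time.time()
y_gemm = x.view(x.shape[0], -1) @ conv.weight.view(1152, -1).t()   # same math as a plain GEMM
torch.cuda.synchronize(); print(f"matmul: {time.time() - t0:.3f}s")

print("max abs diff:", (y_conv.flatten(1) - y_gemm).abs().max().item())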


3. Environment

System: Ubuntu 24.04.3
Kernel: Linux 6.14.0-33-generic
ROCm version: 7.0.1
Python environment: python 3.12
pytorch-triton-rocm 3.5.0+gitbbb06c03
torch 2.10.0.dev20251001+rocm7.0
torchvision 0.25.0.dev20251003+rocm7.0
vllm 0.11.0rc2.dev198+g736fbf4c8.rocm701

The exact versions do not matter; the fix is the same either way.


4. Prerequisite installation

uv venv --python python3.12
source .venv/bin/activate
uv pip install --pre torch torchvision   -i https://pypi.tuna.tsinghua.edu.cn/simple/   --extra-index-url https://download.pytorch.org/whl/nightly/rocm7.0
uv pip install pip
# use pip instead of uv pip from here on, to avoid overwriting the locally installed PyTorch
pip install -U "mineru[core]" -i https://pypi.mirrors.ustc.edu.cn/simple/
# for the vllm installation itself, refer to the official vLLM documentation
# manually install aiter, vllm, amd-smi, etc.: clone them somewhere of your choice, then cd into that directory
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
git submodule sync; git submodule update --init --recursive
python setup.py develop
cd ..
git clone https://github.com/vllm-project/vllm.git
cd vllm/
cp -r /opt/rocm/share/amd_smi ~/Pytorch/vllm/
pip install amd_smi/
pip install --upgrade numba \
    scipy \
    huggingface-hub[cli,hf_transfer] \
    setuptools_scm
pip install -r requirements/rocm.txt
export PYTORCH_ROCM_ARCH="gfx1100"   # set to your own GPU architecture: rocminfo | grep gfx
python setup.py develop
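
Optionally, a quick sanity check that the ROCm builds of torch and the freshly built vllm are the ones being picked up:

python -c "import torch; print(torch.__version__, torch.version.hip, torch.cuda.is_available())"
python -c "import vllm; print(vllm.__version__)"
rocm-smi   # the GPU should be visible here as well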

5. Adding the key Triton kernels to vllm

Here I give two solutions. The first is the one mentioned above that reaches roughly 1.5-1.8 s/it. The second manually rewrites the operator as a matrix multiply: it definitely applies to the 7900 XTX, landing at about 1.3 s/it; other AMD GPUs should still see a speedup over solution 1, but it may not be the optimal implementation for them, and the hand-tuned parts may need adjusting.

Note: uninstall the Triton-backend flash_attn with pip. I spent a long time trying to make it work and it kept erroring out, so the simplest fix is to not use it at all; a possible command is sketched below.
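
For reference, something along these lines; the exact package name depends on how it was installed (pip list will confirm):

pip uninstall -y flash_attn flash-attn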

# locate your own vllm installation path (the XXX used below)
pip show vllm
Key changes to the XXX/vllm/model_executor/models/qwen2_vl.py file:
1. Below line 33 of qwen2_vl.py, add from .qwen2_vl_vision_kernels import triton_conv3d_patchify:
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from .qwen2_vl_vision_kernels import triton_conv3d_patchify
The next steps split into solution 1 (2.1 and 3.1) and solution 2 (2.2 and 3.2); implement one of the two.

Solution 1 — 2.1 In qwen2_vl.py, line 498, class Qwen2VisionPatchEmbed(nn.Module). PS: this is the module AMD has no ready-made kernel for, which causes the fallback.

class Qwen2VisionPatchEmbed(nn.Module):

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x_reshaped = x.view(L, -1, self.temporal_patch_size, self.patch_size,
                            self.patch_size)

        # Call your custom Triton kernel instead of self.proj
        x_out = triton_conv3d_patchify(x_reshaped, self.proj.weight)

        # The output of our kernel is already the correct shape [L, embed_dim]
        return x_out

3.1 Create a qwen2_vl_vision_kernels.py file under the XXX/vllm/model_executor/models/ directory with the following Triton implementation:
import torch
from vllm.triton_utils import tl, triton

@triton.jit
def _conv3d_patchify_kernel(
    # Pointers to tensors
    X, W, Y,
    # Tensor dimensions
    N, C_in, D_in, H_in, W_in,
    C_out, KD, KH, KW,
    # Stride and padding for memory access
    stride_xn, stride_xc, stride_xd, stride_xh, stride_xw,
    stride_wn, stride_wc, stride_wd, stride_wh, stride_ww,
    stride_yn, stride_yc,
    # Triton-specific metaparameters
    BLOCK_SIZE: tl.constexpr,
):
    """
    Triton kernel for a non-overlapping 3D patching convolution.
    Each kernel instance computes one output value for one patch.
    """
    # Get the program IDs for the N (patch) and C_out (output channel) dimensions
    pid_n = tl.program_id(0)  # The index of the patch we are processing
    pid_cout = tl.program_id(1) # The index of the output channel we are computing

    # --- Calculate memory pointers ---
    # Pointer to the start of the current input patch
    x_ptr = X + (pid_n * stride_xn)
    # Pointer to the start of the current filter (weight)
    w_ptr = W + (pid_cout * stride_wn)
    # Pointer to where the output will be stored
    y_ptr = Y + (pid_n * stride_yn + pid_cout * stride_yc)

    # --- Perform the convolution (element-wise product and sum) ---
    # This is a dot product between the flattened patch and the flattened filter.
    accumulator = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Iterate over the elements of the patch/filter
    for c_offset in range(0, C_in):
        for d_offset in range(0, KD):
            for h_offset in range(0, KH):
                # Unrolled loop for the innermost dimension (width) for performance
                for w_offset in range(0, KW, BLOCK_SIZE):
                    # Create masks to handle cases where KW is not a multiple of BLOCK_SIZE
                    w_range = w_offset + tl.arange(0, BLOCK_SIZE)
                    w_mask = w_range < KW

                    # Calculate offsets to load data
                    patch_offset = (c_offset * stride_xc + d_offset * stride_xd +
                                    h_offset * stride_xh + w_range * stride_xw)
                    filter_offset = (c_offset * stride_wc + d_offset * stride_wd +
                                     h_offset * stride_wh + w_range * stride_ww)

                    # Load patch and filter data, applying masks
                    patch_vals = tl.load(x_ptr + patch_offset, mask=w_mask, other=0.0)
                    filter_vals = tl.load(w_ptr + filter_offset, mask=w_mask, other=0.0)

                    # Multiply and accumulate
                    accumulator += patch_vals.to(tl.float32) * filter_vals.to(tl.float32)

    # Sum the accumulator block and store the single output value
    output_val = tl.sum(accumulator, axis=0)
    tl.store(y_ptr, output_val)


def triton_conv3d_patchify(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """
    Python wrapper for the 3D patching convolution Triton kernel.
    """
    # Get tensor dimensions
    N, C_in, D_in, H_in, W_in = x.shape
    C_out, _, KD, KH, KW = weight.shape

    # Create the output tensor
    # The output of this specific conv is (N, C_out, 1, 1, 1), which we squeeze
    Y = torch.empty((N, C_out), dtype=x.dtype, device=x.device)

    # Define the grid for launching the Triton kernel
    # Each kernel instance handles one patch (N) for one output channel (C_out)
    grid = (N, C_out)

    # Launch the kernel
    # We pass all strides to make the kernel flexible
    _conv3d_patchify_kernel[grid](
        x, weight, Y,
        N, C_in, D_in, H_in, W_in,
        C_out, KD, KH, KW,
        x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
        weight.stride(0), weight.stride(1), weight.stride(2), weight.stride(3), weight.stride(4),
        Y.stride(0), Y.stride(1),
        BLOCK_SIZE=16, # A reasonable default, can be tuned
    )

    return Y
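
A one-off correctness check I would run before wiring the kernel into vllm; the shapes are kept small so the slow F.conv3d reference path is tolerable, and the import below assumes the script is run from inside the models/ directory:

import torch
import torch.nn.functional as F
from qwen2_vl_vision_kernels import triton_conv3d_patchify  # run from the models/ directory

x = torch.randn(64, 3, 2, 14, 14, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1152, 3, 2, 14, 14, device="cuda", dtype=torch.bfloat16)

ref = F.conv3d(x, w, stride=(2, 14, 14)).flatten(1)   # reference result, shape (64, 1152)
out = triton_conv3d_patchify(x, w)                    # Triton kernel result, same shape
print("max abs diff:", (ref.float() - out.float()).abs().max().item())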

Solution 2 — 2.2 In qwen2_vl.py, line 498, class Qwen2VisionPatchEmbed(nn.Module); same module, same missing-kernel fallback. Here we go straight from the 5D tensor to a matrix multiply in one step.

class Qwen2VisionPatchEmbed(nn.Module):

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)

        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x_reshaped_5d = x.view(L, -1, self.temporal_patch_size, self.patch_size,
                               self.patch_size)

        return triton_conv3d_patchify(x_reshaped_5d, self.proj.weight)

3.2 Create a qwen2_vl_vision_kernels.py file under the XXX/vllm/model_executor/models/ directory with the following Triton implementation:
import torch
from vllm.triton_utils import tl, triton

@triton.jit
def _conv_gemm_kernel(
    A, B, C, M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K
    c = accumulator.to(C.dtype.element_ty)
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def triton_conv3d_patchify(x_5d: torch.Tensor, weight_5d: torch.Tensor) -> torch.Tensor:
    N_patches, _, _, _, _ = x_5d.shape
    C_out, _, _, _, _ = weight_5d.shape
    A = x_5d.view(N_patches, -1)
    B = weight_5d.view(C_out, -1).transpose(0, 1).contiguous()
    M, K = A.shape
    _K, N = B.shape
    assert K == _K
    C = torch.empty((M, N), device=A.device, dtype=A.dtype)

    # --- Hand-tuned configuration for the 7900 XTX; other GPUs may need a different combination (Triton autotune on AMD gave me no benefit) ---
    best_config = {
        'BLOCK_M': 128,
        'BLOCK_N': 128,
        'BLOCK_K': 32,
    }
    num_stages = 4
    num_warps = 8

    grid = (triton.cdiv(M, best_config['BLOCK_M']),
            triton.cdiv(N, best_config['BLOCK_N']))

    _conv_gemm_kernel[grid](
        A, B, C,
        M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        **best_config,
        num_stages=num_stages,
        num_warps=num_warps
    )

    return C
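
Since the block sizes above are tuned by hand for the 7900 XTX, other GPUs may want a quick manual sweep. A rough sketch, calling the kernel directly with typical patchify shapes (run from the models/ directory; all names as defined above):

import itertools
import time
import torch
import triton
from qwen2_vl_vision_kernels import _conv_gemm_kernel

M, K, N = 4096, 1176, 1152   # patches x (3*2*14*14) x embed_dim, roughly what the model sees
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
C = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)

def run(bm, bn, bk, warps):
    grid = (triton.cdiv(M, bm), triton.cdiv(N, bn))
    _conv_gemm_kernel[grid](A, B, C, M, N, K,
                            A.stride(0), A.stride(1),
                            B.stride(0), B.stride(1),
                            C.stride(0), C.stride(1),
                            BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk,
                            num_stages=4, num_warps=warps)

for bm, bn, bk, warps in itertools.product([64, 128], [64, 128], [32, 64], [4, 8]):
    run(bm, bn, bk, warps)                      # warm-up / compile
    torch.cuda.synchronize(); t0 = time.time()
    for _ in range(20):
        run(bm, bn, bk, warps)
    torch.cuda.synchronize()
    print(f"BLOCK_M={bm} BLOCK_N={bn} BLOCK_K={bk} warps={warps}: {(time.time() - t0) / 20 * 1e3:.3f} ms")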

4. Running mineru-gradio again after closing the terminal raises a LoRA error; patch the code to skip it

pip show mineru_vl_utils

Open XXX/mineru_vl_utils/vlm_client/vllm_async_engine_client.py and change line 58, self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer(), to:

        try:
            self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer()
        except AttributeError:
            # if get_lora_tokenizer is not available, fall back to the plain tokenizer
            self.tokenizer = vllm_async_llm.tokenizer

Finally, set these two environment variables and you are good to go:

export MINERU_MODEL_SOURCE=modelscope
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
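
For CLI runs (as opposed to mineru-gradio), something like the following should work with the standard mineru flags; adjust the paths to your setup:

mineru -p your_document.pdf -o ./output -b vlm-vllm-engine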

6. The vllm backend is now sorted; what remains is the dilated-convolution problem in the doclayout-yolo model used by the layout step of the pipeline backend

I already posted an answer under the DocLayout-YOLO topic, so I will not repeat the pipeline's dilated-convolution fix here; follow that link for the details.

Find where your doclayout-yolo is installed with the command below, then modify the file described in the linked reply:

pip show doclayout-yolo