Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
- id: isort
args: ["--multi-line=7", "--sl", "--profile", "black", "--filter-files"]

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.272"
hooks:
- id: ruff

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-case-conflict
- id: check-yaml
exclude: "mkdocs.yml|recipe/meta.yaml"
- id: pretty-format-json
args: [--autofix]
- id: requirements-txt-fixer

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
- id: remove-crlf
files: \.md$
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$

- repo: local
hooks:
- id: clang-format
name: clang-format
description: Format files with ClangFormat
entry: bash .clang_format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$

exclude: |
^jointContribution/
89 changes: 82 additions & 7 deletions setup_ops.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,63 @@
import os
import re
import os.path as osp
import re

import paddle
from paddle.utils.cpp_extension import CppExtension
from paddle.utils.cpp_extension import CUDAExtension
from paddle.utils.cpp_extension import setup

PADDLE_PATH = os.path.dirname(paddle.__file__)
PADDLE_INCLUDE_PATH = os.path.join(PADDLE_PATH, "include")
PADDLE_LIB_PATH = os.path.join(PADDLE_PATH, "libs")
BASE_DIR = "/workspace/wangguan12/xpu"
os.environ["XHPC_PATH"] = BASE_DIR + "/xhpc-ubuntu2004_x86_64"
os.environ["XRE_PATH"] = BASE_DIR + "/xre-Linux-x86_64-5.0.21.22"
os.environ["CLANG_PATH"] = BASE_DIR + "/xtdk-llvm15-ubuntu2004_x86_64"
os.environ["BKCL_PATH"] = BASE_DIR + "/xccl_rdma-ubuntu_x86_64"
# os.environ['XFT_PATH'] = os.environ['XHPC_PATH'] # XFT在XHPC目录下
# os.environ['XBLAS_PATH'] = os.environ['XHPC_PATH'] # XBLAS在XHPC目录下

BKCL_PATH = os.getenv("BKCL_PATH")
if BKCL_PATH is None:
BKCL_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xpu")
BKCL_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libbkcl.so")
else:
BKCL_INC_PATH = os.path.join(BKCL_PATH, "include")
BKCL_LIB_PATH = os.path.join(BKCL_PATH, "so", "libbkcl.so")

# XFT_PATH = os.getenv("XFT_PATH")
# if XFT_PATH is None:
# XFT_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xft")
# XFT_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxft.so")
# else:
# XFT_INC_PATH = os.path.join(XFT_PATH, "include")
# XFT_LIB_PATH = os.path.join(XFT_PATH, "so", "libxft.so")

XRE_PATH = os.getenv("XRE_PATH")
if XRE_PATH is None:
XRE_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xre")
XRE_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxpucuda.so")
else:
XRE_INC_PATH = os.path.join(XRE_PATH, "include")
XRE_LIB_PATH = os.path.join(XRE_PATH, "so", "libxpucuda.so")

# XFA_PATH = os.getenv("XFA_PATH")
# if XFA_PATH is None:
# XFA_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xhpc", "xfa")
# XFA_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxpu_flash_attention.so")
# else:
# XFA_INC_PATH = os.path.join(XFA_PATH, "include")
# XFA_LIB_PATH = os.path.join(XFA_PATH, "so", "libxpu_flash_attention.so")

# XBLAS_PATH = os.getenv("XBLAS_PATH")
# if XBLAS_PATH is None:
# XBLAS_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH, "xhpc", "xblas")
# XBLAS_LIB_PATH = os.path.join(PADDLE_LIB_PATH, "libxpu_blas.so")
# else:
# XBLAS_INC_PATH = os.path.join(XBLAS_PATH, "include")
# XBLAS_LIB_PATH = os.path.join(XBLAS_PATH, "so", "libxpu_blas.so")


def get_version():
current_dir = osp.dirname(osp.abspath(__file__))
Expand All @@ -18,6 +69,7 @@ def get_version():

raise RuntimeError("Cannot find __version__ in paddle_scatter/__init__.py")


__version__ = get_version()


Expand Down Expand Up @@ -47,12 +99,15 @@ def get_sources():
else:
if item.endswith(".cc"):
cpp_files.append(os.path.join(csrc_dir_path, item))
return csrc_dir_path, cpp_files
return [csrc_dir_path], cpp_files


def get_extensions():
Extension = CppExtension
extra_compile_args = {'cxx': ['-O3']}
extra_objects = []
include_dirs, sources = get_sources()

extra_compile_args = {"cxx": ["-O3"]}
if paddle.device.is_compiled_with_cuda():
set_cuda_archs()
Extension = CUDAExtension
Expand All @@ -61,12 +116,30 @@ def get_extensions():
nvcc_flags += ["-O3"]
nvcc_flags += ["--expt-relaxed-constexpr"]
extra_compile_args["nvcc"] = nvcc_flags
elif paddle.device.is_compiled_with_xpu():
include_dirs += [
XRE_INC_PATH,
# XFT_INC_PATH,
BKCL_LIB_PATH,
# XFA_INC_PATH,
# XBLAS_INC_PATH,
]
extra_objects += [
XRE_LIB_PATH,
# XFT_LIB_PATH,
BKCL_LIB_PATH,
# XFA_LIB_PATH,
# XBLAS_LIB_PATH,
]
extra_compile_args["cxx"] = ["-D_GLIBCXX_USE_CXX11_ABI=1", "-DPADDLE_WITH_XPU"]
else:
raise ("Only CUDA and XPU devices are supported")

src = get_sources()
ext_modules = [
Extension(
sources=src[1],
include_dirs=src[0],
sources=sources,
include_dirs=include_dirs,
extra_objects=extra_objects,
extra_compile_args=extra_compile_args,
)
]
Expand All @@ -80,6 +153,8 @@ def get_extensions():
version=__version__,
author="NKNaN",
url="https://github.com/PFCCLab/paddle_scatter",
description="Paddle extension of scatter and segment operators with min and max reduction methods, originally from https://github.com/rusty1s/pytorch_scatter",
description="Paddle extension of scatter and segment operators \
with min and max reduction methods, \
originally from https://github.com/rusty1s/pytorch_scatter",
ext_modules=get_extensions(),
)