mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Don't call set_device in tl.dot (#1646)
Calling `set_device` inside `tl.dot` breaks multiprocess compilation.
This commit is contained in:
committed by
GitHub
parent
fb40bf1954
commit
0daee68d71
@@ -1,7 +1,5 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
from collections import namedtuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -198,39 +196,6 @@ def test_jit_noinline() -> None:
|
||||
assert inline_ttir != noinline_ttir
|
||||
|
||||
|
||||
# Lightweight stand-in for the compiler's kernel-specialization descriptor;
# fields presumably list argument indices divisible by 16 / equal to 1 — confirm against triton.compile.
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
|
||||
|
||||
|
||||
def compile_fn(config, cc):
    """Multiprocessing target: compile a tiny subtraction kernel.

    Only warms the compilation cache (``warm_cache_only=True``); nothing is
    launched on a device. The kernel source below is deliberately left
    untouched so the compiled artifact stays identical.
    """
    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

    # Build the call spec once, then hand it to the compiler.
    compile_spec = dict(
        fn=kernel_sub,
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device=0,
        constants={3: 32},
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )
    triton.compile(**compile_spec)
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
    """Compilation must succeed inside a spawn-started child (exitcode 0)."""
    major, minor = torch.cuda.get_device_capability(0)
    cc = 10 * major + minor
    config = instance_descriptor(tuple(range(4)), ())

    multiprocessing.set_start_method('spawn')
    worker = multiprocessing.Process(target=compile_fn, args=(config, cc))
    worker.start()
    worker.join()
    assert worker.exitcode == 0
|
||||
|
||||
|
||||
def test_memory_leak() -> None:
|
||||
@triton.jit
|
||||
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
|
||||
|
||||
83
python/test/unit/runtime/test_subproc.py
Normal file
83
python/test/unit/runtime/test_subproc.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
tmpdir = ".tmp"


def reset_tmp_dir():
    """Point Triton's compilation cache at *tmpdir* and wipe any stale contents."""
    os.environ["TRITON_CACHE_DIR"] = tmpdir
    # Guard clause: nothing to clean up if the directory was never created.
    if not os.path.exists(tmpdir):
        return
    shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
# Lightweight stand-in for the compiler's kernel-specialization descriptor;
# fields presumably list argument indices divisible by 16 / equal to 1 — confirm against triton.compile.
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
|
||||
|
||||
|
||||
def compile_fn(config, cc):
    """Multiprocessing target: compile a tiny subtraction kernel.

    Cache warm-up only (``warm_cache_only=True``); no kernel launch. The
    jitted source is kept byte-identical so the cached artifact is unchanged.
    """
    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

    # Assemble the compiler arguments up front, then invoke once.
    compile_spec = dict(
        fn=kernel_sub,
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device=0,
        constants={3: 32},
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )
    triton.compile(**compile_spec)
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
    """Compilation must succeed inside a fork-started child (exitcode 0)."""
    major, minor = torch.cuda.get_device_capability(0)
    cc = 10 * major + minor
    config = instance_descriptor(tuple(range(4)), ())

    # NOTE(review): set_start_method raises if a start method was already
    # chosen in this interpreter; test ordering presumably guarantees it
    # runs first here — confirm.
    multiprocessing.set_start_method('fork')
    worker = multiprocessing.Process(target=compile_fn, args=(config, cc))
    worker.start()
    worker.join()
    assert worker.exitcode == 0
|
||||
|
||||
|
||||
def compile_fn_dot(config, cc):
    """Multiprocessing target: compile a kernel that exercises ``tl.dot``.

    Regression companion for the set_device removal in tl.dot — compilation
    must not touch the CUDA context of the parent process. Cache warm-up
    only; the jitted source is kept byte-identical.
    """
    @triton.jit
    def kernel_dot(Z):
        offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
        z = tl.load(Z + offs)
        z = tl.dot(z, z)
        tl.store(Z + offs, z)

    compile_spec = dict(
        fn=kernel_dot,
        signature={0: "*fp32"},
        device=0,
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )
    triton.compile(**compile_spec)
|
||||
|
||||
|
||||
def test_compile_in_forked_subproc() -> None:
    """``tl.dot`` must compile in a fork-started child (exitcode 0).

    Relies on the start method already being 'fork' (set by the preceding
    test in this module); asserts that precondition explicitly.
    """
    reset_tmp_dir()
    major, minor = torch.cuda.get_device_capability(0)
    cc = 10 * major + minor
    config = instance_descriptor(tuple(range(1)), ())

    assert multiprocessing.get_start_method() == 'fork'
    child = multiprocessing.Process(target=compile_fn_dot, args=(config, cc))
    child.start()
    child.join()
    assert child.exitcode == 0
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations # remove after python 3.11
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Sequence, Tuple, TypeVar
|
||||
|
||||
import triton
|
||||
from . import core as tl
|
||||
from triton._C.libtriton.triton import ir
|
||||
|
||||
@@ -1181,18 +1180,6 @@ def dot(lhs: tl.tensor,
|
||||
allow_tf32: bool,
|
||||
out_dtype: tl.dtype,
|
||||
builder: ir.builder) -> tl.tensor:
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
if torch.version.hip is None:
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
capability = triton.runtime.jit.get_device_capability(device)
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < 70:
|
||||
assert (
|
||||
not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
|
||||
), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
|
||||
assert lhs.type.is_block() and rhs.type.is_block()
|
||||
assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
|
||||
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
|
||||
|
||||
Reference in New Issue
Block a user