[FRONTEND] Don't call set_device in tl.dot (#1646)

Calling set_device from inside tl.dot breaks multiprocess compilation.
Author: Natalia Gimelshein
Date: 2023-05-10 17:39:27 -07:00
Committed by: GitHub
Parent: fb40bf1954
Commit: 0daee68d71
3 changed files with 83 additions and 48 deletions
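
For context, here is a minimal sketch of the failure mode the new test guards against, assuming a CUDA-capable machine: once the parent process initializes CUDA (which the device query removed from tl.dot below triggers through torch), a child created with the fork start method can no longer use CUDA. The helper names are illustrative and not part of this commit.

# Illustrative sketch only: why initializing CUDA in the parent process
# breaks fork-based multiprocess compilation.
import multiprocessing

import torch


def child_uses_cuda():
    # The CUDA context inherited across fork() is unusable: if the parent
    # already initialized CUDA, this call is expected to raise
    # "Cannot re-initialize CUDA in forked subprocess".
    torch.empty(1, device="cuda")


if __name__ == "__main__":
    # The parent initializes CUDA via a device query -- analogous to what the
    # capability check removed from semantic.dot() below used to trigger.
    torch.cuda.current_device()

    ctx = multiprocessing.get_context("fork")
    proc = ctx.Process(target=child_uses_cuda)
    proc.start()
    proc.join()
    # The child dies with a non-zero exit code; the new
    # test_compile_in_forked_subproc asserts exitcode == 0 instead, which only
    # holds once the frontend stops touching the device.
    print("child exit code:", proc.exitcode)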


@@ -1,7 +1,5 @@
import multiprocessing
import os
import shutil
from collections import namedtuple
import pytest
import torch
@@ -198,39 +196,6 @@ def test_jit_noinline() -> None:
    assert inline_ttir != noinline_ttir


instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])


def compile_fn(config, cc):

    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

    triton.compile(
        fn=kernel_sub,
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device=0,
        constants={3: 32},
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )


def test_compile_in_subproc() -> None:
    major, minor = torch.cuda.get_device_capability(0)
    cc = major * 10 + minor
    config = instance_descriptor(tuple(range(4)), ())

    multiprocessing.set_start_method('spawn')
    proc = multiprocessing.Process(
        target=compile_fn,
        args=(config, cc))
    proc.start()
    proc.join()
    assert proc.exitcode == 0


def test_memory_leak() -> None:

    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):


@@ -0,0 +1,83 @@
import multiprocessing
import os
import shutil
from collections import namedtuple

import torch

import triton
import triton.language as tl

tmpdir = ".tmp"


def reset_tmp_dir():
    os.environ["TRITON_CACHE_DIR"] = tmpdir
    if os.path.exists(tmpdir):
        shutil.rmtree(tmpdir)


instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])


def compile_fn(config, cc):

    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

    triton.compile(
        fn=kernel_sub,
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device=0,
        constants={3: 32},
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )


def test_compile_in_subproc() -> None:
    major, minor = torch.cuda.get_device_capability(0)
    cc = major * 10 + minor
    config = instance_descriptor(tuple(range(4)), ())

    multiprocessing.set_start_method('fork')
    proc = multiprocessing.Process(
        target=compile_fn,
        args=(config, cc))
    proc.start()
    proc.join()
    assert proc.exitcode == 0


def compile_fn_dot(config, cc):

    @triton.jit
    def kernel_dot(Z):
        offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
        z = tl.load(Z + offs)
        z = tl.dot(z, z)
        tl.store(Z + offs, z)

    triton.compile(
        fn=kernel_dot,
        signature={0: "*fp32"},
        device=0,
        configs=[config],
        warm_cache_only=True,
        cc=cc,
    )


def test_compile_in_forked_subproc() -> None:
    reset_tmp_dir()
    major, minor = torch.cuda.get_device_capability(0)
    cc = major * 10 + minor
    config = instance_descriptor(tuple(range(1)), ())

    assert multiprocessing.get_start_method() == 'fork'
    proc = multiprocessing.Process(
        target=compile_fn_dot,
        args=(config, cc))
    proc.start()
    proc.join()
    assert proc.exitcode == 0


@@ -3,7 +3,6 @@ from __future__ import annotations # remove after python 3.11
from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar
import triton
from . import core as tl
from triton._C.libtriton.triton import ir
@@ -1181,18 +1180,6 @@ def dot(lhs: tl.tensor,
        allow_tf32: bool,
        out_dtype: tl.dtype,
        builder: ir.builder) -> tl.tensor:
    try:
        import torch
    except ImportError:
        raise ImportError("Triton requires PyTorch to be installed")
    if torch.version.hip is None:
        device = triton.runtime.jit.get_current_device()
        capability = triton.runtime.jit.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
        if capability < 70:
            assert (
                not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
            ), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
    assert lhs.type.is_block() and rhs.type.is_block()
    assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
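
The removed block was the only place where dot reached out to the runtime for the current device; the assertions that remain inspect only the operands. Below is a hedged sketch of the device-free alternative the new tests lean on, with the target capability supplied by the caller (for example derived from torch.cuda.get_device_capability and passed as cc to triton.compile); the helper name is hypothetical and not code this commit adds.

# Illustrative only: validating dot operands against an explicitly supplied
# compute capability instead of querying the current CUDA device from the
# language frontend.
def check_dot_dtypes_for_capability(rhs_is_fp16: bool, rhs_is_fp8: bool, capability: int) -> None:
    if capability < 70:
        assert not (rhs_is_fp16 or rhs_is_fp8), \
            "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"


# The caller computes the capability once, the same way the new tests derive
# cc = major * 10 + minor before calling triton.compile.
check_dot_dtypes_for_capability(rhs_is_fp16=False, rhs_is_fp8=False, capability=70)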