[RUNTIME] Fix memory leak in kernel invocation (#1358)

Fixes a bug that causes Triton to leak 32 bytes on every kernel
invocation.

Also solves https://github.com/pytorch/pytorch/issues/96937
This commit is contained in:
Horace He
2023-03-16 17:52:06 -07:00
committed by GitHub
parent 611a2dc9bf
commit 1d2871d0d1
3 changed files with 48 additions and 0 deletions

View File

@@ -196,3 +196,15 @@ def test_compile_in_subproc() -> None:
proc.start()
proc.join()
assert proc.exitcode == 0
def test_memory_leak() -> None:
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
xnumel = 10
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)

View File

@@ -0,0 +1,35 @@
import gc
import tracemalloc
import torch
import triton
import triton.language as tl
def test_memory_leak() -> None:
    """Regression test: launching a Triton kernel repeatedly must not grow
    Python-heap memory.

    Per the fixing commit, each kernel invocation previously leaked 32 bytes;
    100 launches would therefore grow traced memory by ~3200 bytes, well over
    the 1000-byte threshold asserted below.
    """

    # Trivial copy kernel (inductor-style codegen: note the parameter
    # ``xnumel`` is immediately overwritten with the constant 10, and the
    # store offset adds ``tl.zeros([XBLOCK], tl.int32)`` — presumably part of
    # the original leak reproducer; kept as-is).
    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        xnumel = 10
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)

    # Measure Python-side allocations only (tracemalloc does not see CUDA
    # device memory); stop() in ``finally`` so tracing never leaks into
    # other tests even if an assertion fires.
    tracemalloc.start()
    try:
        inp = torch.randn(10, device='cuda')
        out = torch.randn(10, device='cuda')
        # Warm-up launch: triggers JIT compilation and any one-time caching
        # so those allocations are excluded from the measured window.
        kernel[(10,)](inp, out, 10, XBLOCK=16)
        # Collect before each snapshot so only *uncollectable* growth —
        # i.e. a genuine leak — is counted.
        gc.collect()
        begin, _ = tracemalloc.get_traced_memory()
        for _ in range(100):
            kernel[(10,)](inp, out, 10, XBLOCK=16)
        gc.collect()
        end, _ = tracemalloc.get_traced_memory()
        # 100 launches must stay under 1000 bytes of net growth; a 32-byte
        # per-launch leak would show up as ~3200 bytes.
        assert end - begin < 1000
    finally:
        tracemalloc.stop()