Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2024-09-26 15:03:00 -07:00
19 changed files with 223 additions and 133 deletions

View File

@@ -9,6 +9,7 @@
::: tinygrad.Tensor.full_like
::: tinygrad.Tensor.zeros_like
::: tinygrad.Tensor.ones_like
::: tinygrad.Tensor.from_blob
## Creation (random)

View File

@@ -5,6 +5,7 @@ export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
export BENCHMARK=10 DEBUG=2

View File

@@ -5,8 +5,9 @@ export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
export WANDB=1
python3 examples/mlperf/model_train.py
RUNMLPERF=1 python3 examples/mlperf/model_train.py

View File

@@ -6,6 +6,7 @@ export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
# pip install -e ".[mlperf]"

View File

@@ -5,6 +5,7 @@ export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
export BENCHMARK=10 DEBUG=2

View File

@@ -5,8 +5,9 @@ export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
export WANDB=1
python3 examples/mlperf/model_train.py
RUNMLPERF=1 python3 examples/mlperf/model_train.py

View File

@@ -6,6 +6,7 @@ export SUBMISSION_PLATFORM="tinybox_red"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
# pip install -e ".[mlperf]"

View File

@@ -0,0 +1,87 @@
"""Example: wrap externally-managed OpenCL/QCOM GPU memory as tinygrad Tensors via Tensor.from_blob.

Demonstrates three cases: a plain OpenCL buffer, buffers re-created each iteration under
TinyJit, and a 2D OpenCL image. The Tensors never take ownership of the memory; the
OpenCL objects are released manually with clReleaseMemObject once computation is done.
"""
import ctypes, array
from hexdump import hexdump
from tinygrad.runtime.ops_gpu import GPUDevice
from tinygrad.helpers import getenv, to_mv, mv_address
from tinygrad.dtype import dtypes
from tinygrad import Tensor, TinyJit
from tinygrad.runtime.autogen import opencl as cl

# IOCTL=1 optionally traces GPU submissions through the qcom ioctl interposer
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl  # noqa: F401 # pylint: disable=unused-import

# create raw opencl buffer.
gdev = GPUDevice()
cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 0x100, None, status := ctypes.c_int32())
assert status.value == 0

# fill it with something for fun
data = memoryview(array.array('I', [i for i in range(64)]))
cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 0x100, mv_address(data), 0, None, None)
cl.clFinish(gdev.queue) # wait writes to complete

# get raw gpu pointer from opencl buffer.
## get buf desc: the cl_mem handle is itself a pointer to a driver-side descriptor struct
hexdump(to_mv(ctypes.addressof(cl_buf), 0x40))
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0]
## get buf device ptr from the descriptor
hexdump(to_mv(cl_buf_desc_ptr, 0x100))
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer.

# create QCOM tensor with the externally managed buffer
x = Tensor.from_blob(rawbuf_ptr, (8, 8), dtype=dtypes.int, device='QCOM')
y = (x + 1).numpy()
print(y)

# all calculations are done, safe to free the object
cl.clReleaseMemObject(cl_buf)

# all together with jit: a fresh external buffer (and fresh Tensor) is wrapped every iteration
@TinyJit
def calc(x): return x + 2
for i in range(4):
  cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 2*2*4, None, status := ctypes.c_int32())
  assert status.value == 0
  data = memoryview(array.array('I', [x+i for x in range(2*2)]))
  cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 2*2*4, mv_address(data), 0, None, None)
  cl.clFinish(gdev.queue) # wait writes to complete
  cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0]
  rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20]
  y = calc(x = Tensor.from_blob(rawbuf_ptr, (2, 2), dtype=dtypes.int, device='QCOM')).numpy()
  print(f'jit {i}\n', y)
  # all calculations are done, safe to free the object
  cl.clReleaseMemObject(cl_buf)

# now images!
h, w = 128, 128
cl_img = cl.clCreateImage2D(gdev.context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT), w, h, 0, None, status := ctypes.c_int32())
assert status.value == 0

# fill it with something for fun
data = memoryview(array.array('f', [i for i in range(h*w*4)]))
cl.clEnqueueWriteImage(gdev.queue, cl_img, False, (ctypes.c_size_t * 3)(0,0,0), (ctypes.c_size_t * 3)(w,h,1), 0, 0, mv_address(data), 0, None, None)
cl.clFinish(gdev.queue) # wait writes to complete

# get raw gpu pointer from opencl buffer.
## get buf desc (same descriptor-chasing trick as for plain buffers above)
hexdump(to_mv(ctypes.addressof(cl_img), 0x40))
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_img), 8).cast('Q')[0]
## get buf device ptr
hexdump(to_mv(cl_buf_desc_ptr, 0x100))
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer.

# create QCOM tensor with the externally managed buffer
# dtypes.imageh = cl.cl_image_format(cl.CL_RGBA, cl.CL_HALF_FLOAT)
# dtypes.imagef = cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT)
x = Tensor.from_blob(rawbuf_ptr, (h*w*4,), dtype=dtypes.imagef((h,w)), device='QCOM')
y = (x + 1).numpy()
print(y)

# all calculations are done, safe to free the object
cl.clReleaseMemObject(cl_img)

View File

@@ -42,10 +42,9 @@ class TestLinearizerDumb(unittest.TestCase):
prg = k.to_program()
print(prg.src)
Device[Device.DEFAULT].compiler.compile_cached(prg.src)
with self.assertRaises(AssertionError):
gate_count = len([x for x in prg.src.splitlines() if "if" in x])
assert gate_count == 1, f"must have only one gate {gate_count} != 1"
assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
gate_count = len([x for x in prg.src.splitlines() if "if" in x])
assert gate_count == 1, f"must have only one gate {gate_count} != 1"
assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_max_simplify_and_cancel(self):
@@ -95,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase):
prg = k.to_program()
print(prg.src)
if_uops = [u for u in k.uops if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assertIn(len(if_uops), {1,3})
conditions = if_uops[0].src[0].sparents
self.assertLessEqual(len(conditions), 9)

View File

@@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase):
k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
assert k is not None
ifs = [u for u in k.uops if u.op is UOps.IF]
self.assertEqual(len(ifs), 1)
self.assertEqual(len(ifs), 4)
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)

View File

@@ -1825,33 +1825,21 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear", align_corners=True),
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True))
def test_interpolate_nearest(self):
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
def test_interpolate_nearest(self, mode="nearest"):
for in_sz, out_sz in [((13,),(9,)), ((9,),(13,))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest"))
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode),
lambda x: Tensor.interpolate(x, size=out_sz, mode=mode))
for in_sz, out_sz in [((13,10),(9,11)), ((13,9),(11,10)), ((9,11),(10,13))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest"))
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode),
lambda x: Tensor.interpolate(x, size=out_sz, mode=mode))
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest"))
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode),
lambda x: Tensor.interpolate(x, size=out_sz, mode=mode))
def test_interpolate_nearest_exact(self):
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact"))
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact"))
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact"))
def test_interpolate_nearest_exact(self): self.test_interpolate_nearest("nearest-exact")
def test_interpolate_bilinear(self):
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:

View File

@@ -1,10 +1,10 @@
import subprocess
import numpy as np
import torch
import unittest, copy, mmap, random, math
import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import getenv, temp, CI, _METADATA
from tinygrad.helpers import getenv, temp, CI, _METADATA, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported
@@ -330,6 +330,17 @@ class TestTinygrad(unittest.TestCase):
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
def test_tensor_from_blob(self):
  # Tensor.from_blob must wrap existing memory without copying: mutating the
  # backing memoryview in place is visible through the tensor on the next realize.
  x = memoryview(bytearray(16)).cast('I')
  t = Tensor.from_blob(mv_address(x), (4,), dtype=dtypes.int, device="CLANG")
  z = (t+1)
  np.testing.assert_equal(z.numpy(), [1, 1, 1, 1])
  x[:] = array.array('I', [0, 1, 2, 3])  # mutate the blob in place, no tinygrad API involved
  z = (t+1)
  np.testing.assert_equal(z.numpy(), [1, 2, 3, 4])
def test_tensor_list_dtype(self):
for arr in ([1], [[[1]]], [[1,1],[1,1]], [[[1,1],[1,1]],[[1,1],[1,1]]]):
assert Tensor(arr).dtype == dtypes.default_int

View File

@@ -784,6 +784,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
# prefer uops that are loop children
else:
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
if u.op is UOps.IF and len(u.src) == 1: priority += 10000000 # if penalty
return priority
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import multiprocessing, decimal, statistics, random
from dataclasses import dataclass
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
@@ -48,6 +48,7 @@ class BufferOptions:
cpu_access: bool = False
host: bool = False
nolru: bool = False
external_ptr: Optional[int] = None
class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
@@ -75,9 +76,11 @@ class Buffer:
def ref(self, cnt): self.base._lb_refcount += cnt
def is_allocated(self) -> bool: return hasattr(self, '_buf')
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
def allocate(self, opaque=None) -> Buffer:
def allocate(self, opaque=None, external_ptr=None) -> Buffer:
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
self.allocator = Device[self.device].allocator
if external_ptr is not None:
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
if self._base is not None:
self._base.ensure_allocated()
assert hasattr(self.allocator, "offset"), "offset function required for view"
@@ -99,7 +102,7 @@ class Buffer:
def nbytes(self): return self.size*self.dtype.itemsize
def __del__(self):
if not hasattr(self, '_buf'): return
if self._base is None:
if self._base is None and (self.options is None or self.options.external_ptr is None):
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
self.allocator.free(self._buf, self.nbytes, self.options)
def __repr__(self):
@@ -162,7 +165,8 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
else: super().free(opaque, size, options)
class _MallocAllocator(LRUAllocator):
def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
def _alloc(self, size:int, options:BufferOptions):
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))

View File

@@ -594,6 +594,7 @@ spec = PatternMatcher([
# STORE takes a <buf, idx, val, gate?>
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat())), lambda: True),
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(UOps.IF))), lambda: True),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
@@ -614,7 +615,8 @@ spec = PatternMatcher([
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier>
# if has a <gate, barrier?>
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
(UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),

View File

@@ -32,6 +32,12 @@ asm_for_op: Dict[Op, Callable] = {
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
}
def load_store_ptr_arithmetic(x:UOp, buf:UOp, alu:Optional[UOp]=None, const:Optional[UOp]=None) -> UOp:
src = list(x.src)
src[0] = buf.cast(dtypes.int64) if alu is None else (buf.cast(dtypes.int64) + alu.cast(dtypes.int64)*buf.dtype.itemsize)
src[1] = UOp.const(dtypes.int64, 0 if const is None else const.arg*buf.dtype.itemsize)
return x.replace(src=tuple(src))
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
ptx_matcher = constant_folder+PatternMatcher([
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
@@ -41,31 +47,14 @@ ptx_matcher = constant_folder+PatternMatcher([
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
# fix the gates for load/store (low quality!)
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))),
lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool), UPat())),
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool))),
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat.var("g", dtypes.int))),
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
# ptr_ar (load/store)
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("alu"), UPat.cvar("const")]))),
lambda root, alu, const: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
const*root.src[0].dtype.itemsize)+root.src[2:])),
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.cvar("const"))),
lambda root, const: UOp(root.op, root.dtype,
(root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))),
lambda root, alu: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, 0))+root.src[2:])),
# load/store bool -> uint8
(UPat(UOps.LOAD, dtypes.bool, name="x"),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:2] + ((x.src[2].cast(dtypes.uint8),) if len(x.src) >= 3 else ()) + x.src[3:]).cast(dtypes.bool)),
(UPat(UOps.STORE, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, x.src[0:2] + (x.src[2].cast(dtypes.uint8),) + x.src[3:])),
# load/store use pointer arithmetic
(UPat((UOps.LOAD, UOps.STORE), name="x", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL), name="buf"),
UPat.any(UPat.var("alu")+UPat.cvar("const"), UPat.cvar("const"), UPat.var("alu")))), load_store_ptr_arithmetic),
])
class PTXRenderer(Renderer):

View File

@@ -7,61 +7,25 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
else:
val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype)
return val
def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp, gate:Optional[UOp]=None) -> str:
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
else:
val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};"
# TODO: this if should be in UOps, not here
return f"if ({r[gate]}) {{ {val} }}" if gate is not None else val
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
return f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
def render_alu(r:CStyleLanguage, x:UOp):
if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
else: operands = [r[v] for v in x.src]
return r.code_for_op[x.arg](*operands, x.dtype)
def render_gep(r:CStyleLanguage, x:UOp):
from_ssa = x.src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
return (r[x.src[0]] if from_ssa else f"{(r[x.src[0]])}") + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) \
or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")
base_pm = PatternMatcher([
base_rewrite = PatternMatcher([
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
(UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"),
(UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
(UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
# load/store image
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))),
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
# TODO: this if should be in UOps, not here
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4)),
UPat.var("gate", dtype=dtypes.bool))), lambda r,buf,idx,var,gate: f"if ({r[gate]}) {{ write_imagef({r[buf]}, {r[idx]}, {r[var]}); }}"),
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True),
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
# r method accesses
(UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
(UPat(UOps.VECTORIZE, name="x"),
lambda r,x: f"{r.float4.replace('float4', r.render_dtype(x.dtype))}" + \
(f"{{{','.join([r[y] for y in x.src])}}}" if r.device == "CLANG" else f"({','.join([r[y] for y in x.src])})")),
(UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)),
(UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)),
(UPat(UOps.CAST, name="x"), lambda r,x: f"({r.render_dtype(x.dtype)})({r[x.src[0]]})"),
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"(*(({r.buffer_prefix}{r.render_dtype(x.dtype)}*)&{r[x.src[0]]}))"),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"),
(UPat(UOps.BARRIER), lambda r: r.barrier),
(UPat(UOps.NOOP, name="x"), lambda r,x: r[x.src[0]]),
@@ -76,12 +40,18 @@ base_pm = PatternMatcher([
(UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda r,x: f"{x.arg}u"),
(UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda r,x: "1" if x.arg else "0"),
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
# function calls
(UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate", dtype=dtypes.bool))), render_store),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store),
(UPat(UOps.ALU, name="x"), render_alu),
(UPat(UOps.GEP, name="x"), render_gep),
# load/store
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"),
lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"),
lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True),
lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
# alu/gep
(UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg](
*([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)),
(UPat(UOps.GEP, name="x"), lambda r,x: r[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
extra_pm = PatternMatcher([
@@ -92,6 +62,9 @@ extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(UOps.BITCAST, name="x"),
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
# gate any stores that aren't gated with ifs
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(UOps.STORE, src=store.src[:3]+(UOp(UOps.IF, src=(store.src[3],)),))),
])
class CStyleLanguage(Renderer):
@@ -122,13 +95,9 @@ class CStyleLanguage(Renderer):
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
string_rewrite = base_rewrite
extra_matcher = extra_pm
# returns a str expression of the casted xs with the given type
def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
return f"({self.render_dtype(var_dtype)})({x})"
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
@@ -176,7 +145,7 @@ class CStyleLanguage(Renderer):
UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
r[u] = f"{prefix}{c[prefix]}"
l = cast(str, base_pm.rewrite(u, ctx=self))
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
@@ -244,8 +213,17 @@ class OpenCLRenderer(CStyleLanguage):
float4 = "(float4)"
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
def render_cast(self, x, var_dtype, bitcast=False) -> str:
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_{r.render_dtype(x.dtype)}({r[x.src[0]]})"),
# load/store image (OpenCL)
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)))),
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
@@ -255,12 +233,14 @@ class IntelRenderer(OpenCLRenderer):
device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
string_rewrite = PatternMatcher([
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda r,x: f"intel_convert_bfloat16_as_ushort({r[x[0]]})"),
(UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"),
]) + OpenCLRenderer.string_rewrite
def render_dtype(self, var_dtype:DType) -> str:
return f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype)
def render_cast(self, x, var_dtype, bitcast=False, from_dtype=None) -> str:
return f"intel_convert_bfloat16_as_ushort({x[0]})" if (var_dtype, from_dtype) == (dtypes.bfloat16, dtypes.float) else \
(f"intel_convert_as_bfloat16_float({x[0]})" if (var_dtype, from_dtype) == (dtypes.float, dtypes.bfloat16) else \
super().render_cast(x, var_dtype, bitcast))
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = []
@@ -291,15 +271,21 @@ class MetalRenderer(CStyleLanguage):
# uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
type_map = {dtypes.bfloat16: "bfloat"}
code_for_op = {**CStyleLanguage().code_for_op,
BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
# precise::sin
code_for_op = {**CStyleLanguage().code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"}
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
for op in [BinaryOps.MAX, UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
]) + extra_pm
string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_type<{r.render_dtype(x.dtype)}>({r[x.src[0]]})"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])

View File

@@ -294,7 +294,8 @@ class QCOMAllocator(HCQAllocator):
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size)
else: texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
# Extend HCQBuffer with texture-related info.
texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16
@@ -308,7 +309,7 @@ class QCOMAllocator(HCQAllocator):
return texture
return self.device._gpu_alloc(size, map_to_cpu=True)
return QCOMBuffer(options.external_ptr, size) if options.external_ptr else self.device._gpu_alloc(size, map_to_cpu=True)
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
while src_off < src_size:

View File

@@ -398,6 +398,21 @@ class Tensor:
"""
return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)
@staticmethod
def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
  """
  Exposes the pointer as a Tensor without taking ownership of the original data.
  The pointer must remain valid for the entire lifetime of the created Tensor.

  You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
  Additionally, all other keyword arguments are passed to the constructor of the tensor.
  """
  r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs)
  # bind the buffer to the externally-owned pointer instead of allocating fresh memory;
  # with external_ptr set the Buffer will not free this memory on deletion
  r.lazydata.buffer.allocate(external_ptr=ptr)
  del r.lazydata.srcs # fake realize: drop srcs so the lazybuffer is treated as already computed
  return r
_seed: int = int(time.time())
_device_seeds: Dict[str, int] = {}
_device_rng_counters: Dict[str, Tensor] = {}