mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
::: tinygrad.Tensor.full_like
|
||||
::: tinygrad.Tensor.zeros_like
|
||||
::: tinygrad.Tensor.ones_like
|
||||
::: tinygrad.Tensor.from_blob
|
||||
|
||||
## Creation (random)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
export BENCHMARK=10 DEBUG=2
|
||||
|
||||
@@ -5,8 +5,9 @@ export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
export WANDB=1
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
RUNMLPERF=1 python3 examples/mlperf/model_train.py
|
||||
@@ -6,6 +6,7 @@ export SUBMISSION_PLATFORM="tinybox_green"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
# pip install -e ".[mlperf]"
|
||||
|
||||
@@ -5,6 +5,7 @@ export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
export BENCHMARK=10 DEBUG=2
|
||||
|
||||
@@ -5,8 +5,9 @@ export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
export WANDB=1
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
RUNMLPERF=1 python3 examples/mlperf/model_train.py
|
||||
@@ -6,6 +6,7 @@ export SUBMISSION_PLATFORM="tinybox_red"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export IGNORE_JIT_FIRST_BEAM=1
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
# pip install -e ".[mlperf]"
|
||||
|
||||
87
extra/qcom_gpu_driver/qcom_opencl_interop.py
Normal file
87
extra/qcom_gpu_driver/qcom_opencl_interop.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import ctypes, array
|
||||
from hexdump import hexdump
|
||||
from tinygrad.runtime.ops_gpu import GPUDevice
|
||||
from tinygrad.helpers import getenv, to_mv, mv_address
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad import Tensor, TinyJit
|
||||
from tinygrad.runtime.autogen import opencl as cl
|
||||
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
# create raw opencl buffer.
|
||||
gdev = GPUDevice()
|
||||
cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 0x100, None, status := ctypes.c_int32())
|
||||
assert status.value == 0
|
||||
|
||||
# fill it with something for fun
|
||||
data = memoryview(array.array('I', [i for i in range(64)]))
|
||||
cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 0x100, mv_address(data), 0, None, None)
|
||||
cl.clFinish(gdev.queue) # wait writes to complete
|
||||
|
||||
# get raw gpu pointer from opencl buffer.
|
||||
|
||||
## get buf desc
|
||||
hexdump(to_mv(ctypes.addressof(cl_buf), 0x40))
|
||||
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0]
|
||||
|
||||
## get buf device ptr
|
||||
hexdump(to_mv(cl_buf_desc_ptr, 0x100))
|
||||
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer.
|
||||
|
||||
# create QCOM tensor with the externally managed buffer
|
||||
x = Tensor.from_blob(rawbuf_ptr, (8, 8), dtype=dtypes.int, device='QCOM')
|
||||
y = (x + 1).numpy()
|
||||
print(y)
|
||||
|
||||
# all calculations are done, save to free the object
|
||||
cl.clReleaseMemObject(cl_buf)
|
||||
|
||||
# all together with jit
|
||||
@TinyJit
|
||||
def calc(x): return x + 2
|
||||
|
||||
for i in range(4):
|
||||
cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 2*2*4, None, status := ctypes.c_int32())
|
||||
assert status.value == 0
|
||||
data = memoryview(array.array('I', [x+i for x in range(2*2)]))
|
||||
cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 2*2*4, mv_address(data), 0, None, None)
|
||||
cl.clFinish(gdev.queue) # wait writes to complete
|
||||
|
||||
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0]
|
||||
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20]
|
||||
|
||||
y = calc(x = Tensor.from_blob(rawbuf_ptr, (2, 2), dtype=dtypes.int, device='QCOM')).numpy()
|
||||
print(f'jit {i}\n', y)
|
||||
|
||||
# all calculations are done, save to free the object
|
||||
cl.clReleaseMemObject(cl_buf)
|
||||
|
||||
# now images!
|
||||
|
||||
h, w = 128, 128
|
||||
cl_img = cl.clCreateImage2D(gdev.context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT), w, h, 0, None, status := ctypes.c_int32())
|
||||
assert status.value == 0
|
||||
|
||||
# fill it with something for fun
|
||||
data = memoryview(array.array('f', [i for i in range(h*w*4)]))
|
||||
cl.clEnqueueWriteImage(gdev.queue, cl_img, False, (ctypes.c_size_t * 3)(0,0,0), (ctypes.c_size_t * 3)(w,h,1), 0, 0, mv_address(data), 0, None, None)
|
||||
cl.clFinish(gdev.queue) # wait writes to complete
|
||||
|
||||
# get raw gpu pointer from opencl buffer.
|
||||
|
||||
## get buf desc
|
||||
hexdump(to_mv(ctypes.addressof(cl_img), 0x40))
|
||||
cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_img), 8).cast('Q')[0]
|
||||
|
||||
## get buf device ptr
|
||||
hexdump(to_mv(cl_buf_desc_ptr, 0x100))
|
||||
rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer.
|
||||
|
||||
# create QCOM tensor with the externally managed buffer
|
||||
# dtypes.imageh = cl.cl_image_format(cl.CL_RGBA, cl.CL_HALF_FLOAT)
|
||||
# dtypes.imagef = cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT)
|
||||
x = Tensor.from_blob(rawbuf_ptr, (h*w*4,), dtype=dtypes.imagef((h,w)), device='QCOM')
|
||||
y = (x + 1).numpy()
|
||||
print(y)
|
||||
|
||||
# all calculations are done, save to free the object
|
||||
cl.clReleaseMemObject(cl_img)
|
||||
@@ -42,10 +42,9 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
prg = k.to_program()
|
||||
print(prg.src)
|
||||
Device[Device.DEFAULT].compiler.compile_cached(prg.src)
|
||||
with self.assertRaises(AssertionError):
|
||||
gate_count = len([x for x in prg.src.splitlines() if "if" in x])
|
||||
assert gate_count == 1, f"must have only one gate {gate_count} != 1"
|
||||
assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
|
||||
gate_count = len([x for x in prg.src.splitlines() if "if" in x])
|
||||
assert gate_count == 1, f"must have only one gate {gate_count} != 1"
|
||||
assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
|
||||
def test_max_simplify_and_cancel(self):
|
||||
@@ -95,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
prg = k.to_program()
|
||||
print(prg.src)
|
||||
if_uops = [u for u in k.uops if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertIn(len(if_uops), {1,3})
|
||||
conditions = if_uops[0].src[0].sparents
|
||||
self.assertLessEqual(len(conditions), 9)
|
||||
|
||||
|
||||
@@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
assert k is not None
|
||||
ifs = [u for u in k.uops if u.op is UOps.IF]
|
||||
self.assertEqual(len(ifs), 1)
|
||||
self.assertEqual(len(ifs), 4)
|
||||
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
|
||||
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)
|
||||
|
||||
|
||||
@@ -1825,33 +1825,21 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear", align_corners=True),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True))
|
||||
|
||||
def test_interpolate_nearest(self):
|
||||
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
|
||||
def test_interpolate_nearest(self, mode="nearest"):
|
||||
for in_sz, out_sz in [((13,),(9,)), ((9,),(13,))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest"))
|
||||
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode=mode))
|
||||
for in_sz, out_sz in [((13,10),(9,11)), ((13,9),(11,10)), ((9,11),(10,13))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest"))
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode=mode))
|
||||
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest"))
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode=mode))
|
||||
|
||||
def test_interpolate_nearest_exact(self):
|
||||
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact"))
|
||||
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact"))
|
||||
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact"))
|
||||
def test_interpolate_nearest_exact(self): self.test_interpolate_nearest("nearest-exact")
|
||||
|
||||
def test_interpolate_bilinear(self):
|
||||
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest, copy, mmap, random, math
|
||||
import unittest, copy, mmap, random, math, array
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.helpers import getenv, temp, CI, _METADATA
|
||||
from tinygrad.helpers import getenv, temp, CI, _METADATA, mv_address
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import is_dtype_supported
|
||||
@@ -330,6 +330,17 @@ class TestTinygrad(unittest.TestCase):
|
||||
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype
|
||||
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else
|
||||
|
||||
def test_tensor_from_blob(self):
|
||||
x = memoryview(bytearray(16)).cast('I')
|
||||
|
||||
t = Tensor.from_blob(mv_address(x), (4,), dtype=dtypes.int, device="CLANG")
|
||||
z = (t+1)
|
||||
np.testing.assert_equal(z.numpy(), [1, 1, 1, 1])
|
||||
|
||||
x[:] = array.array('I', [0, 1, 2, 3])
|
||||
z = (t+1)
|
||||
np.testing.assert_equal(z.numpy(), [1, 2, 3, 4])
|
||||
|
||||
def test_tensor_list_dtype(self):
|
||||
for arr in ([1], [[[1]]], [[1,1],[1,1]], [[[1,1],[1,1]],[[1,1],[1,1]]]):
|
||||
assert Tensor(arr).dtype == dtypes.default_int
|
||||
|
||||
@@ -784,6 +784,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
# prefer uops that are loop children
|
||||
else:
|
||||
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
|
||||
if u.op is UOps.IF and len(u.src) == 1: priority += 10000000 # if penalty
|
||||
return priority
|
||||
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import multiprocessing, decimal, statistics, random
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
|
||||
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
|
||||
@@ -48,6 +48,7 @@ class BufferOptions:
|
||||
cpu_access: bool = False
|
||||
host: bool = False
|
||||
nolru: bool = False
|
||||
external_ptr: Optional[int] = None
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
|
||||
@@ -75,9 +76,11 @@ class Buffer:
|
||||
def ref(self, cnt): self.base._lb_refcount += cnt
|
||||
def is_allocated(self) -> bool: return hasattr(self, '_buf')
|
||||
def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
|
||||
def allocate(self, opaque=None) -> Buffer:
|
||||
def allocate(self, opaque=None, external_ptr=None) -> Buffer:
|
||||
assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
|
||||
self.allocator = Device[self.device].allocator
|
||||
if external_ptr is not None:
|
||||
self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
|
||||
if self._base is not None:
|
||||
self._base.ensure_allocated()
|
||||
assert hasattr(self.allocator, "offset"), "offset function required for view"
|
||||
@@ -99,7 +102,7 @@ class Buffer:
|
||||
def nbytes(self): return self.size*self.dtype.itemsize
|
||||
def __del__(self):
|
||||
if not hasattr(self, '_buf'): return
|
||||
if self._base is None:
|
||||
if self._base is None and (self.options is None or self.options.external_ptr is None):
|
||||
if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
|
||||
self.allocator.free(self._buf, self.nbytes, self.options)
|
||||
def __repr__(self):
|
||||
@@ -162,7 +165,8 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
|
||||
else: super().free(opaque, size, options)
|
||||
|
||||
class _MallocAllocator(LRUAllocator):
|
||||
def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
|
||||
def _alloc(self, size:int, options:BufferOptions):
|
||||
return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
|
||||
def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
|
||||
def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
|
||||
def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
|
||||
|
||||
@@ -594,6 +594,7 @@ spec = PatternMatcher([
|
||||
# STORE takes a <buf, idx, val, gate?>
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat())), lambda: True),
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(UOps.IF))), lambda: True),
|
||||
|
||||
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
||||
(UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
|
||||
@@ -614,7 +615,8 @@ spec = PatternMatcher([
|
||||
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# if has a <gate, barrier>
|
||||
# if has a <gate, barrier?>
|
||||
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
|
||||
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
|
||||
(UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),
|
||||
|
||||
|
||||
@@ -32,6 +32,12 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
|
||||
def load_store_ptr_arithmetic(x:UOp, buf:UOp, alu:Optional[UOp]=None, const:Optional[UOp]=None) -> UOp:
|
||||
src = list(x.src)
|
||||
src[0] = buf.cast(dtypes.int64) if alu is None else (buf.cast(dtypes.int64) + alu.cast(dtypes.int64)*buf.dtype.itemsize)
|
||||
src[1] = UOp.const(dtypes.int64, 0 if const is None else const.arg*buf.dtype.itemsize)
|
||||
return x.replace(src=tuple(src))
|
||||
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
ptx_matcher = constant_folder+PatternMatcher([
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
@@ -41,31 +47,14 @@ ptx_matcher = constant_folder+PatternMatcher([
|
||||
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
|
||||
for op in asm_for_op.keys() if op not in supports_half],
|
||||
# fix the gates for load/store (low quality!)
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))),
|
||||
lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
|
||||
lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool), UPat())),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool))),
|
||||
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
|
||||
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat.var("g", dtypes.int))),
|
||||
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
|
||||
# ptr_ar (load/store)
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
|
||||
UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("alu"), UPat.cvar("const")]))),
|
||||
lambda root, alu, const: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
const*root.src[0].dtype.itemsize)+root.src[2:])),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.cvar("const"))),
|
||||
lambda root, const: UOp(root.op, root.dtype,
|
||||
(root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))),
|
||||
lambda root, alu: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, 0))+root.src[2:])),
|
||||
# load/store bool -> uint8
|
||||
(UPat(UOps.LOAD, dtypes.bool, name="x"),
|
||||
lambda x: UOp(x.op, dtypes.uint8, x.src[0:2] + ((x.src[2].cast(dtypes.uint8),) if len(x.src) >= 3 else ()) + x.src[3:]).cast(dtypes.bool)),
|
||||
(UPat(UOps.STORE, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: UOp(x.op, dtypes.void, x.src[0:2] + (x.src[2].cast(dtypes.uint8),) + x.src[3:])),
|
||||
# load/store use pointer arithmetic
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="x", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL), name="buf"),
|
||||
UPat.any(UPat.var("alu")+UPat.cvar("const"), UPat.cvar("const"), UPat.var("alu")))), load_store_ptr_arithmetic),
|
||||
])
|
||||
|
||||
class PTXRenderer(Renderer):
|
||||
|
||||
@@ -7,61 +7,25 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.renderer import Renderer, TensorCore
|
||||
|
||||
def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
|
||||
sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
|
||||
if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
|
||||
else:
|
||||
val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
|
||||
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype)
|
||||
return val
|
||||
|
||||
def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp, gate:Optional[UOp]=None) -> str:
|
||||
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
|
||||
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
|
||||
if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
|
||||
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
|
||||
else:
|
||||
val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};"
|
||||
# TODO: this if should be in UOps, not here
|
||||
return f"if ({r[gate]}) {{ {val} }}" if gate is not None else val
|
||||
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
|
||||
return f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
|
||||
|
||||
def render_alu(r:CStyleLanguage, x:UOp):
|
||||
if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
|
||||
else: operands = [r[v] for v in x.src]
|
||||
return r.code_for_op[x.arg](*operands, x.dtype)
|
||||
|
||||
def render_gep(r:CStyleLanguage, x:UOp):
|
||||
from_ssa = x.src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
|
||||
return (r[x.src[0]] if from_ssa else f"{(r[x.src[0]])}") + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) \
|
||||
or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")
|
||||
|
||||
base_pm = PatternMatcher([
|
||||
base_rewrite = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
|
||||
(UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"),
|
||||
(UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
|
||||
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
|
||||
(UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
|
||||
# load/store image
|
||||
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
|
||||
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))),
|
||||
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
|
||||
# TODO: this if should be in UOps, not here
|
||||
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4)),
|
||||
UPat.var("gate", dtype=dtypes.bool))), lambda r,buf,idx,var,gate: f"if ({r[gate]}) {{ write_imagef({r[buf]}, {r[idx]}, {r[var]}); }}"),
|
||||
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
|
||||
# r method accesses
|
||||
(UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
|
||||
(UPat(UOps.VECTORIZE, name="x"),
|
||||
lambda r,x: f"{r.float4.replace('float4', r.render_dtype(x.dtype))}" + \
|
||||
(f"{{{','.join([r[y] for y in x.src])}}}" if r.device == "CLANG" else f"({','.join([r[y] for y in x.src])})")),
|
||||
(UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)),
|
||||
(UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)),
|
||||
(UPat(UOps.CAST, name="x"), lambda r,x: f"({r.render_dtype(x.dtype)})({r[x.src[0]]})"),
|
||||
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"(*(({r.buffer_prefix}{r.render_dtype(x.dtype)}*)&{r[x.src[0]]}))"),
|
||||
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"),
|
||||
(UPat(UOps.BARRIER), lambda r: r.barrier),
|
||||
(UPat(UOps.NOOP, name="x"), lambda r,x: r[x.src[0]]),
|
||||
@@ -76,12 +40,18 @@ base_pm = PatternMatcher([
|
||||
(UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda r,x: f"{x.arg}u"),
|
||||
(UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda r,x: "1" if x.arg else "0"),
|
||||
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
|
||||
# function calls
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load),
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate", dtype=dtypes.bool))), render_store),
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store),
|
||||
(UPat(UOps.ALU, name="x"), render_alu),
|
||||
(UPat(UOps.GEP, name="x"), render_gep),
|
||||
# load/store
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"),
|
||||
lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"),
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"),
|
||||
lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)),
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True),
|
||||
lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"),
|
||||
# alu/gep
|
||||
(UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg](
|
||||
*([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)),
|
||||
(UPat(UOps.GEP, name="x"), lambda r,x: r[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
])
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
@@ -92,6 +62,9 @@ extra_pm = PatternMatcher([
|
||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(UOps.BITCAST, name="x"),
|
||||
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
|
||||
lambda store: UOp(UOps.STORE, src=store.src[:3]+(UOp(UOps.IF, src=(store.src[3],)),))),
|
||||
])
|
||||
|
||||
class CStyleLanguage(Renderer):
|
||||
@@ -122,13 +95,9 @@ class CStyleLanguage(Renderer):
|
||||
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
|
||||
|
||||
string_rewrite = base_rewrite
|
||||
extra_matcher = extra_pm
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
|
||||
if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
|
||||
return f"({self.render_dtype(var_dtype)})({x})"
|
||||
|
||||
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||
@@ -176,7 +145,7 @@ class CStyleLanguage(Renderer):
|
||||
UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
|
||||
r[u] = f"{prefix}{c[prefix]}"
|
||||
|
||||
l = cast(str, base_pm.rewrite(u, ctx=self))
|
||||
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
|
||||
@@ -244,8 +213,17 @@ class OpenCLRenderer(CStyleLanguage):
|
||||
float4 = "(float4)"
|
||||
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
|
||||
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
|
||||
def render_cast(self, x, var_dtype, bitcast=False) -> str:
|
||||
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_{r.render_dtype(x.dtype)}({r[x.src[0]]})"),
|
||||
# load/store image (OpenCL)
|
||||
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
|
||||
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)))),
|
||||
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
|
||||
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
|
||||
@@ -255,12 +233,14 @@ class IntelRenderer(OpenCLRenderer):
|
||||
device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
|
||||
tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
|
||||
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda r,x: f"intel_convert_bfloat16_as_ushort({r[x[0]]})"),
|
||||
(UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"),
|
||||
]) + OpenCLRenderer.string_rewrite
|
||||
|
||||
def render_dtype(self, var_dtype:DType) -> str:
|
||||
return f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype)
|
||||
def render_cast(self, x, var_dtype, bitcast=False, from_dtype=None) -> str:
|
||||
return f"intel_convert_bfloat16_as_ushort({x[0]})" if (var_dtype, from_dtype) == (dtypes.bfloat16, dtypes.float) else \
|
||||
(f"intel_convert_as_bfloat16_float({x[0]})" if (var_dtype, from_dtype) == (dtypes.float, dtypes.bfloat16) else \
|
||||
super().render_cast(x, var_dtype, bitcast))
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
prefix = []
|
||||
@@ -291,15 +271,21 @@ class MetalRenderer(CStyleLanguage):
|
||||
# uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
|
||||
type_map = {dtypes.bfloat16: "bfloat"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op,
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
|
||||
|
||||
def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
|
||||
return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
|
||||
# precise::sin
|
||||
code_for_op = {**CStyleLanguage().code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"}
|
||||
|
||||
# upcast to float32 all the ops that don't support bfloat16
|
||||
extra_matcher = PatternMatcher([
|
||||
# NOTE: this is copied from PTX
|
||||
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
|
||||
for op in [BinaryOps.MAX, UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
|
||||
]) + extra_pm
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_type<{r.render_dtype(x.dtype)}>({r[x.src[0]]})"),
|
||||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
|
||||
|
||||
@@ -294,7 +294,8 @@ class QCOMAllocator(HCQAllocator):
|
||||
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
|
||||
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
|
||||
|
||||
texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
|
||||
if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size)
|
||||
else: texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
|
||||
|
||||
# Extend HCQBuffer with texture-related info.
|
||||
texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16
|
||||
@@ -308,7 +309,7 @@ class QCOMAllocator(HCQAllocator):
|
||||
|
||||
return texture
|
||||
|
||||
return self.device._gpu_alloc(size, map_to_cpu=True)
|
||||
return QCOMBuffer(options.external_ptr, size) if options.external_ptr else self.device._gpu_alloc(size, map_to_cpu=True)
|
||||
|
||||
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
|
||||
while src_off < src_size:
|
||||
|
||||
@@ -398,6 +398,21 @@ class Tensor:
|
||||
"""
|
||||
return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor:
|
||||
"""
|
||||
Exposes the pointer as a Tensor without taking ownership of the original data.
|
||||
The pointer must remain valid for the entire lifetime of the created Tensor.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
"""
|
||||
|
||||
r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs)
|
||||
r.lazydata.buffer.allocate(external_ptr=ptr)
|
||||
del r.lazydata.srcs # fake realize
|
||||
return r
|
||||
|
||||
_seed: int = int(time.time())
|
||||
_device_seeds: Dict[str, int] = {}
|
||||
_device_rng_counters: Dict[str, Tensor] = {}
|
||||
|
||||
Reference in New Issue
Block a user