diff --git a/docs/tensor/creation.md b/docs/tensor/creation.md index 4ab4d0cfe8..2be2204e3d 100644 --- a/docs/tensor/creation.md +++ b/docs/tensor/creation.md @@ -9,6 +9,7 @@ ::: tinygrad.Tensor.full_like ::: tinygrad.Tensor.zeros_like ::: tinygrad.Tensor.ones_like +::: tinygrad.Tensor.from_blob ## Creation (random) diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh index 368d2cc9b4..98bacec1a1 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh @@ -5,6 +5,7 @@ export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 +export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" export BENCHMARK=10 DEBUG=2 diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh index 672cd2f24e..a741ebb761 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh @@ -5,8 +5,9 @@ export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 +export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" export WANDB=1 -python3 examples/mlperf/model_train.py \ No newline at end of file +RUNMLPERF=1 python3 examples/mlperf/model_train.py \ No newline at end of file diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh index 7c3a928242..963553857e 100755 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh @@ -6,6 +6,7 @@ export SUBMISSION_PLATFORM="tinybox_green" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 +export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" # pip install -e ".[mlperf]" diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh index 368d2cc9b4..98bacec1a1 100644 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh @@ -5,6 +5,7 @@ export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 +export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" export BENCHMARK=10 DEBUG=2 diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh index 672cd2f24e..a741ebb761 100644 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh @@ -5,8 +5,9 @@ export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 +export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" export WANDB=1 -python3 examples/mlperf/model_train.py \ No newline at end of file +RUNMLPERF=1 python3 examples/mlperf/model_train.py \ No newline at end of file diff --git a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh index e6036c027b..5b80731948 100644 --- a/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh +++ b/examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh @@ -6,6 +6,7 @@ export SUBMISSION_PLATFORM="tinybox_red" export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6 export BEAM=4 +export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" # pip install -e ".[mlperf]" diff --git a/extra/qcom_gpu_driver/qcom_opencl_interop.py b/extra/qcom_gpu_driver/qcom_opencl_interop.py new file mode 100644 index 0000000000..d595ba343f --- /dev/null +++ b/extra/qcom_gpu_driver/qcom_opencl_interop.py @@ -0,0 +1,87 @@ +import ctypes, array +from hexdump import hexdump +from tinygrad.runtime.ops_gpu import GPUDevice +from tinygrad.helpers import getenv, to_mv, mv_address +from tinygrad.dtype import dtypes +from tinygrad import Tensor, TinyJit +from tinygrad.runtime.autogen import opencl as cl +if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import + +# create raw opencl buffer. +gdev = GPUDevice() +cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 0x100, None, status := ctypes.c_int32()) +assert status.value == 0 + +# fill it with something for fun +data = memoryview(array.array('I', [i for i in range(64)])) +cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 0x100, mv_address(data), 0, None, None) +cl.clFinish(gdev.queue) # wait writes to complete + +# get raw gpu pointer from opencl buffer. + +## get buf desc +hexdump(to_mv(ctypes.addressof(cl_buf), 0x40)) +cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0] + +## get buf device ptr +hexdump(to_mv(cl_buf_desc_ptr, 0x100)) +rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer. + +# create QCOM tensor with the externally managed buffer +x = Tensor.from_blob(rawbuf_ptr, (8, 8), dtype=dtypes.int, device='QCOM') +y = (x + 1).numpy() +print(y) + +# all calculations are done, save to free the object +cl.clReleaseMemObject(cl_buf) + +# all together with jit +@TinyJit +def calc(x): return x + 2 + +for i in range(4): + cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 2*2*4, None, status := ctypes.c_int32()) + assert status.value == 0 + data = memoryview(array.array('I', [x+i for x in range(2*2)])) + cl.clEnqueueWriteBuffer(gdev.queue, cl_buf, False, 0, 2*2*4, mv_address(data), 0, None, None) + cl.clFinish(gdev.queue) # wait writes to complete + + cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_buf), 8).cast('Q')[0] + rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] + + y = calc(x = Tensor.from_blob(rawbuf_ptr, (2, 2), dtype=dtypes.int, device='QCOM')).numpy() + print(f'jit {i}\n', y) + + # all calculations are done, save to free the object + cl.clReleaseMemObject(cl_buf) + +# now images! + +h, w = 128, 128 +cl_img = cl.clCreateImage2D(gdev.context, cl.CL_MEM_READ_WRITE, cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT), w, h, 0, None, status := ctypes.c_int32()) +assert status.value == 0 + +# fill it with something for fun +data = memoryview(array.array('f', [i for i in range(h*w*4)])) +cl.clEnqueueWriteImage(gdev.queue, cl_img, False, (ctypes.c_size_t * 3)(0,0,0), (ctypes.c_size_t * 3)(w,h,1), 0, 0, mv_address(data), 0, None, None) +cl.clFinish(gdev.queue) # wait writes to complete + +# get raw gpu pointer from opencl buffer. + +## get buf desc +hexdump(to_mv(ctypes.addressof(cl_img), 0x40)) +cl_buf_desc_ptr = to_mv(ctypes.addressof(cl_img), 8).cast('Q')[0] + +## get buf device ptr +hexdump(to_mv(cl_buf_desc_ptr, 0x100)) +rawbuf_ptr = to_mv(cl_buf_desc_ptr, 0x100).cast('Q')[20] # offset 0xA0 is a raw gpu pointer. + +# create QCOM tensor with the externally managed buffer +# dtypes.imageh = cl.cl_image_format(cl.CL_RGBA, cl.CL_HALF_FLOAT) +# dtypes.imagef = cl.cl_image_format(cl.CL_RGBA, cl.CL_FLOAT) +x = Tensor.from_blob(rawbuf_ptr, (h*w*4,), dtype=dtypes.imagef((h,w)), device='QCOM') +y = (x + 1).numpy() +print(y) + +# all calculations are done, save to free the object +cl.clReleaseMemObject(cl_img) diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index 8bd9a62190..ce3ec109ad 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -42,10 +42,9 @@ class TestLinearizerDumb(unittest.TestCase): prg = k.to_program() print(prg.src) Device[Device.DEFAULT].compiler.compile_cached(prg.src) - with self.assertRaises(AssertionError): - gate_count = len([x for x in prg.src.splitlines() if "if" in x]) - assert gate_count == 1, f"must have only one gate {gate_count} != 1" - assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF" + gate_count = len([x for x in prg.src.splitlines() if "if" in x]) + assert gate_count == 1, f"must have only one gate {gate_count} != 1" + assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF" @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local") def test_max_simplify_and_cancel(self): @@ -95,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase): prg = k.to_program() print(prg.src) if_uops = [u for u in k.uops if u.op is UOps.IF] - self.assertEqual(len(if_uops), 1) + self.assertIn(len(if_uops), {1,3}) conditions = if_uops[0].src[0].sparents self.assertLessEqual(len(conditions), 9) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index b305169642..80f97f1222 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase): k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) assert k is not None ifs = [u for u in k.uops if u.op is UOps.IF] - self.assertEqual(len(ifs), 1) + self.assertEqual(len(ifs), 4) #for st in k.uops.sink.src: self.assertEqual(len(st.src), 4) self.assertLessEqual(len(ifs[0].src[0].sparents), 17) diff --git a/test/test_ops.py b/test/test_ops.py index d9713781f3..342b27a415 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1825,33 +1825,21 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear", align_corners=True), lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True)) - def test_interpolate_nearest(self): - for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]: + def test_interpolate_nearest(self, mode="nearest"): + for in_sz, out_sz in [((13,),(9,)), ((9,),(13,))]: helper_test_op([(2,3)+in_sz], - lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"), - lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest")) - for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]: + lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode), + lambda x: Tensor.interpolate(x, size=out_sz, mode=mode)) + for in_sz, out_sz in [((13,10),(9,11)), ((13,9),(11,10)), ((9,11),(10,13))]: helper_test_op([(2,3)+in_sz], - lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"), - lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest")) + lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode), + lambda x: Tensor.interpolate(x, size=out_sz, mode=mode)) for in_sz, out_sz in [((5,2,8),(3,6,4))]: helper_test_op([(2,3)+in_sz], - lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest"), - lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest")) + lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode=mode), + lambda x: Tensor.interpolate(x, size=out_sz, mode=mode)) - def test_interpolate_nearest_exact(self): - for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]: - helper_test_op([(2,3)+in_sz], - lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"), - lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact")) - for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]: - helper_test_op([(2,3)+in_sz], - lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"), - lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact")) - for in_sz, out_sz in [((5,2,8),(3,6,4))]: - helper_test_op([(2,3)+in_sz], - lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="nearest-exact"), - lambda x: Tensor.interpolate(x, size=out_sz, mode="nearest-exact")) + def test_interpolate_nearest_exact(self): self.test_interpolate_nearest("nearest-exact") def test_interpolate_bilinear(self): for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]: diff --git a/test/test_tensor.py b/test/test_tensor.py index 1af4dad263..1f831e91bc 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,10 +1,10 @@ import subprocess import numpy as np import torch -import unittest, copy, mmap, random, math +import unittest, copy, mmap, random, math, array from tinygrad import Tensor, Device, dtypes from tinygrad.engine.schedule import create_schedule -from tinygrad.helpers import getenv, temp, CI, _METADATA +from tinygrad.helpers import getenv, temp, CI, _METADATA, mv_address from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from hypothesis import given, settings, strategies as strat from test.helpers import is_dtype_supported @@ -330,6 +330,17 @@ class TestTinygrad(unittest.TestCase): assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32 # check if ndarray correctly casts to Tensor dtype assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64 # check that it works for something else + def test_tensor_from_blob(self): + x = memoryview(bytearray(16)).cast('I') + + t = Tensor.from_blob(mv_address(x), (4,), dtype=dtypes.int, device="CLANG") + z = (t+1) + np.testing.assert_equal(z.numpy(), [1, 1, 1, 1]) + + x[:] = array.array('I', [0, 1, 2, 3]) + z = (t+1) + np.testing.assert_equal(z.numpy(), [1, 2, 3, 4]) + def test_tensor_list_dtype(self): for arr in ([1], [[[1]]], [[1,1],[1,1]], [[[1,1],[1,1]],[[1,1],[1,1]]]): assert Tensor(arr).dtype == dtypes.default_int diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 18cffefbbd..6c87561f63 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -784,6 +784,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: # prefer uops that are loop children else: priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss]) + if u.op is UOps.IF and len(u.src) == 1: priority += 10000000 # if penalty return priority priorities:Dict[UOp, int] = {u:get_priority(u) for u in children} diff --git a/tinygrad/device.py b/tinygrad/device.py index 4de4d5557d..52e31015ac 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -1,6 +1,6 @@ from __future__ import annotations import multiprocessing, decimal, statistics, random -from dataclasses import dataclass +from dataclasses import dataclass, replace from collections import defaultdict from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array @@ -48,6 +48,7 @@ class BufferOptions: cpu_access: bool = False host: bool = False nolru: bool = False + external_ptr: Optional[int] = None class Buffer: def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, @@ -75,9 +76,11 @@ class Buffer: def ref(self, cnt): self.base._lb_refcount += cnt def is_allocated(self) -> bool: return hasattr(self, '_buf') def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self - def allocate(self, opaque=None) -> Buffer: + def allocate(self, opaque=None, external_ptr=None) -> Buffer: assert not hasattr(self, '_buf'), "can't allocate already allocated buffer" self.allocator = Device[self.device].allocator + if external_ptr is not None: + self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr) if self._base is not None: self._base.ensure_allocated() assert hasattr(self.allocator, "offset"), "offset function required for view" @@ -99,7 +102,7 @@ class Buffer: def nbytes(self): return self.size*self.dtype.itemsize def __del__(self): if not hasattr(self, '_buf'): return - if self._base is None: + if self._base is None and (self.options is None or self.options.external_ptr is None): if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes self.allocator.free(self._buf, self.nbytes, self.options) def __repr__(self): @@ -162,7 +165,8 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method else: super().free(opaque, size, options) class _MallocAllocator(LRUAllocator): - def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)() + def _alloc(self, size:int, options:BufferOptions): + return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)() def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src)) def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src)) def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fcb2f7eaf5..c41a71d0f1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -594,6 +594,7 @@ spec = PatternMatcher([ # STORE takes a (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat())), lambda: True), (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), + (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(UOps.IF))), lambda: True), # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE (UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE), @@ -614,7 +615,8 @@ spec = PatternMatcher([ (UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), (UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), - # if has a + # if has a + (UPat(UOps.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True), (UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True), (UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True), diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index bdb0719fbc..bcf76684d1 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -32,6 +32,12 @@ asm_for_op: Dict[Op, Callable] = { f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};" } +def load_store_ptr_arithmetic(x:UOp, buf:UOp, alu:Optional[UOp]=None, const:Optional[UOp]=None) -> UOp: + src = list(x.src) + src[0] = buf.cast(dtypes.int64) if alu is None else (buf.cast(dtypes.int64) + alu.cast(dtypes.int64)*buf.dtype.itemsize) + src[1] = UOp.const(dtypes.int64, 0 if const is None else const.arg*buf.dtype.itemsize) + return x.replace(src=tuple(src)) + supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] ptx_matcher = constant_folder+PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) @@ -41,31 +47,14 @@ ptx_matcher = constant_folder+PatternMatcher([ *[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"), lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))) for op in asm_for_op.keys() if op not in supports_half], - # fix the gates for load/store (low quality!) - (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))), - lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)), - (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())), - lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)), - (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool), UPat())), - lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)), - (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat.var("z", dtypes.bool))), - lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)), - (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat.var("g", dtypes.int))), - lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)), - # ptr_ar (load/store) - (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), - UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("alu"), UPat.cvar("const")]))), - lambda root, alu, const: UOp(root.op, root.dtype, - (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64), - const*root.src[0].dtype.itemsize)+root.src[2:])), - (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.cvar("const"))), - lambda root, const: UOp(root.op, root.dtype, - (root.src[0].cast(dtypes.int64), - UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])), - (UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))), - lambda root, alu: UOp(root.op, root.dtype, - (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64), - UOp.const(dtypes.int64, 0))+root.src[2:])), + # load/store bool -> uint8 + (UPat(UOps.LOAD, dtypes.bool, name="x"), + lambda x: UOp(x.op, dtypes.uint8, x.src[0:2] + ((x.src[2].cast(dtypes.uint8),) if len(x.src) >= 3 else ()) + x.src[3:]).cast(dtypes.bool)), + (UPat(UOps.STORE, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True), + lambda x: UOp(x.op, dtypes.void, x.src[0:2] + (x.src[2].cast(dtypes.uint8),) + x.src[3:])), + # load/store use pointer arithmetic + (UPat((UOps.LOAD, UOps.STORE), name="x", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL), name="buf"), + UPat.any(UPat.var("alu")+UPat.cvar("const"), UPat.cvar("const"), UPat.var("alu")))), load_store_ptr_arithmetic), ]) class PTXRenderer(Renderer): diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 0ec5582f62..5cf12e94c0 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -7,61 +7,25 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.renderer import Renderer, TensorCore -def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str: - sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]] - if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType): - val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))" - else: - val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]" - - # NOTE: this relies on the load not happening if it's in the unselected branch - if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype) - return val - -def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp, gate:Optional[UOp]=None) -> str: +def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType): sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx] - if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType): - prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix - val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};" - else: - val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};" - # TODO: this if should be in UOps, not here - return f"if ({r[gate]}) {{ {val} }}" if gate is not None else val + if dtype.count > 1 and isinstance(buf.dtype, PtrDType): + return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))" + return f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]" -def render_alu(r:CStyleLanguage, x:UOp): - if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src] - else: operands = [r[v] for v in x.src] - return r.code_for_op[x.arg](*operands, x.dtype) - -def render_gep(r:CStyleLanguage, x:UOp): - from_ssa = x.src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC} - return (r[x.src[0]] if from_ssa else f"{(r[x.src[0]])}") + \ - (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) \ - or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}") - -base_pm = PatternMatcher([ +base_rewrite = PatternMatcher([ (UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]), (UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"), (UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"), (UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"), (UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"), - # load/store image - (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))), - lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"), - (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))), - lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"), - # TODO: this if should be in UOps, not here - (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4)), - UPat.var("gate", dtype=dtypes.bool))), lambda r,buf,idx,var,gate: f"if ({r[gate]}) {{ write_imagef({r[buf]}, {r[idx]}, {r[var]}); }}"), - (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True), - lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"), # r method accesses (UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"), (UPat(UOps.VECTORIZE, name="x"), lambda r,x: f"{r.float4.replace('float4', r.render_dtype(x.dtype))}" + \ (f"{{{','.join([r[y] for y in x.src])}}}" if r.device == "CLANG" else f"({','.join([r[y] for y in x.src])})")), - (UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)), - (UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)), + (UPat(UOps.CAST, name="x"), lambda r,x: f"({r.render_dtype(x.dtype)})({r[x.src[0]]})"), + (UPat(UOps.BITCAST, name="x"), lambda r,x: f"(*(({r.buffer_prefix}{r.render_dtype(x.dtype)}*)&{r[x.src[0]]}))"), (UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"), (UPat(UOps.BARRIER), lambda r: r.barrier), (UPat(UOps.NOOP, name="x"), lambda r,x: r[x.src[0]]), @@ -76,12 +40,18 @@ base_pm = PatternMatcher([ (UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda r,x: f"{x.arg}u"), (UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda r,x: "1" if x.arg else "0"), (UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)), - # function calls - (UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load), - (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate", dtype=dtypes.bool))), render_store), - (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store), - (UPat(UOps.ALU, name="x"), render_alu), - (UPat(UOps.GEP, name="x"), render_gep), + # load/store + (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate")), name="load"), + lambda r,buf,idx,load,var,gate: f"({r[gate]}?{_render_index(r, buf, idx, load.dtype)}:{r[var]})"), + (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True, name="load"), + lambda r,buf,idx,load: _render_index(r, buf, idx, load.dtype)), + (UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), + lambda r,buf,idx,var: f"{_render_index(r, buf, idx, var.dtype)} = {r[var]};"), + # alu/gep + (UPat(UOps.ALU, name="x"), lambda r,x: r.code_for_op[x.arg]( + *([strip_parens(r[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else r[v] for v in x.src]), x.dtype)), + (UPat(UOps.GEP, name="x"), lambda r,x: r[x.src[0]] + \ + (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")), ]) extra_pm = PatternMatcher([ @@ -92,6 +62,9 @@ extra_pm = PatternMatcher([ # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends? (UPat(UOps.BITCAST, name="x"), lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None), + # gate any stores that aren't gated with ifs + (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), + lambda store: UOp(UOps.STORE, src=store.src[:3]+(UOp(UOps.IF, src=(store.src[3],)),))), ]) class CStyleLanguage(Renderer): @@ -122,13 +95,9 @@ class CStyleLanguage(Renderer): BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"} + string_rewrite = base_rewrite extra_matcher = extra_pm - # returns a str expression of the casted xs with the given type - def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str: - if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))" - return f"({self.render_dtype(var_dtype)})({x})" - def get_kernel_modifier(self, uops:List[UOp]) -> str: return "" def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 @@ -176,7 +145,7 @@ class CStyleLanguage(Renderer): UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk") r[u] = f"{prefix}{c[prefix]}" - l = cast(str, base_pm.rewrite(u, ctx=self)) + l = cast(str, self.string_rewrite.rewrite(u, ctx=self)) assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1 @@ -244,8 +213,17 @@ class OpenCLRenderer(CStyleLanguage): float4 = "(float4)" code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"} type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" } - def render_cast(self, x, var_dtype, bitcast=False) -> str: - return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype) + + string_rewrite = PatternMatcher([ + (UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_{r.render_dtype(x.dtype)}({r[x.src[0]]})"), + # load/store image (OpenCL) + (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))), + lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"), + (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)))), + lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"), + (UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtypes.int.vec(2)), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True), + lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"), + ]) + base_rewrite def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or [])) @@ -255,12 +233,14 @@ class IntelRenderer(OpenCLRenderer): device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel " tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]), st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]] + + string_rewrite = PatternMatcher([ + (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda r,x: f"intel_convert_bfloat16_as_ushort({r[x[0]]})"), + (UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda r,x: f"intel_convert_as_bfloat16_float({r[x[0]]})"), + ]) + OpenCLRenderer.string_rewrite + def render_dtype(self, var_dtype:DType) -> str: return f"ushort{var_dtype.count}" if "bfloat16" in var_dtype.name else super().render_dtype(var_dtype) - def render_cast(self, x, var_dtype, bitcast=False, from_dtype=None) -> str: - return f"intel_convert_bfloat16_as_ushort({x[0]})" if (var_dtype, from_dtype) == (dtypes.bfloat16, dtypes.float) else \ - (f"intel_convert_as_bfloat16_float({x[0]})" if (var_dtype, from_dtype) == (dtypes.float, dtypes.bfloat16) else \ - super().render_cast(x, var_dtype, bitcast)) def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: prefix = [] @@ -291,15 +271,21 @@ class MetalRenderer(CStyleLanguage): # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]` extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'] type_map = {dtypes.bfloat16: "bfloat"} - code_for_op = {**CStyleLanguage().code_for_op, - BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})", - UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})", - UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})", - UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})", - UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",} - def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str: - return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype) + # precise::sin + code_for_op = {**CStyleLanguage().code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"} + + # upcast to float32 all the ops that don't support bfloat16 + extra_matcher = PatternMatcher([ + # NOTE: this is copied from PTX + *[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"), + lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))) + for op in [BinaryOps.MAX, UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]] + ]) + extra_pm + + string_rewrite = PatternMatcher([ + (UPat(UOps.BITCAST, name="x"), lambda r,x: f"as_type<{r.render_dtype(x.dtype)}>({r[x.src[0]]})"), + ]) + base_rewrite def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): prefix, wmma_args = ["#include ","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA]) diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index ed46da41c8..65b58684a8 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -294,7 +294,8 @@ class QCOMAllocator(HCQAllocator): pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0 pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add - texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True) + if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size) + else: texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True) # Extend HCQBuffer with texture-related info. texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16 @@ -308,7 +309,7 @@ class QCOMAllocator(HCQAllocator): return texture - return self.device._gpu_alloc(size, map_to_cpu=True) + return QCOMBuffer(options.external_ptr, size) if options.external_ptr else self.device._gpu_alloc(size, map_to_cpu=True) def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0): while src_off < src_size: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 603a8e352d..c56bb5bddf 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -398,6 +398,21 @@ class Tensor: """ return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs) + @staticmethod + def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor: + """ + Exposes the pointer as a Tensor without taking ownership of the original data. + The pointer must remain valid for the entire lifetime of the created Tensor. + + You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. + Additionally, all other keyword arguments are passed to the constructor of the tensor. + """ + + r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs) + r.lazydata.buffer.allocate(external_ptr=ptr) + del r.lazydata.srcs # fake realize + return r + _seed: int = int(time.time()) _device_seeds: Dict[str, int] = {} _device_rng_counters: Dict[str, Tensor] = {}