From f7f94cd62f3504873abac23ef334f14797e27ead Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 7 Oct 2024 19:16:16 +0800 Subject: [PATCH] bitcast cleanup [pr] (#6933) --- tinygrad/device.py | 2 +- tinygrad/engine/lazy.py | 1 + tinygrad/function.py | 5 +++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/device.py b/tinygrad/device.py index 59425efdb7..27e9da8242 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -21,7 +21,7 @@ class _Device: cpn = multiprocessing.current_process().name assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}" x = ix.split(":")[0].upper() - ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501 + ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501 if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}") return ret @property diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index b7c954c9e4..6a9a155709 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -93,6 +93,7 @@ class LazyBuffer(MathTrait): self.base.forced_realize = True return self + def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True) def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer: if self.dtype == dtype: return self if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)") diff --git a/tinygrad/function.py b/tinygrad/function.py index e4540eae60..c31ba0a1f5 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -19,9 +19,10 @@ class ContiguousBackward(Function): class Cast(Function): def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer: self.input_dtype, self.bitcast = x.dtype, bitcast - return x.cast(dtype, bitcast) + return x.bitcast(dtype) if self.bitcast else x.cast(dtype) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: + return grad_output.bitcast(self.input_dtype) if self.bitcast else grad_output.cast(self.input_dtype) # ************* unary ops *************