bitcast cleanup [pr] (#6933)

This commit is contained in:
George Hotz
2024-10-07 19:16:16 +08:00
committed by GitHub
parent 0cf815a93a
commit f7f94cd62f
3 changed files with 5 additions and 3 deletions

View File

@@ -21,7 +21,7 @@ class _Device:
cpn = multiprocessing.current_process().name
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
x = ix.split(":")[0].upper()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
return ret
@property

View File

@@ -93,6 +93,7 @@ class LazyBuffer(MathTrait):
self.base.forced_realize = True
return self
def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
if self.dtype == dtype: return self
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")

View File

@@ -19,9 +19,10 @@ class ContiguousBackward(Function):
class Cast(Function):
def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.cast(dtype, bitcast)
return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.bitcast(self.input_dtype) if self.bitcast else grad_output.cast(self.input_dtype)
# ************* unary ops *************