mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Revert "add vcount to PtrDtype (#7388)"
This reverts commit b086584d64.
This commit is contained in:
@@ -15,14 +15,12 @@ class DType:
|
||||
count: int
|
||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
|
||||
def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
|
||||
@property
|
||||
def vcount(self): return self.count
|
||||
def vec(self, sz:int):
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local, v)
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, local)
|
||||
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -32,25 +30,18 @@ class ImageDType(DType):
|
||||
local: bool = False # images are never local
|
||||
def scalar(self) -> DType: return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def ptr(self, local=False, v=1) -> Union[PtrDType, ImageDType]: return self
|
||||
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PtrDType(DType):
|
||||
base: DType
|
||||
local: bool
|
||||
v: int
|
||||
def __hash__(self): return super().__hash__()
|
||||
def scalar(self) -> DType: return self.ptr(self.local, 1)
|
||||
def vec(self, sz:int) -> DType: return self.ptr(self.local, sz)
|
||||
@property
|
||||
def vcount(self): return self.v
|
||||
# local isn't used in the compare
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
def __repr__(self):
|
||||
arg = (["local=true"] if self.local else []) + ([f"v={self.v}"] if self.v != 1 else [])
|
||||
return f"{self.base.__repr__()}.ptr({','.join(arg)})"
|
||||
def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
|
||||
@@ -318,8 +318,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is UOps.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
||||
if self.op is UOps.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
||||
i = (i,)
|
||||
if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self
|
||||
assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}"
|
||||
if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.count == len(i)): return self
|
||||
assert len(i) >= 1 and all(x < self.dtype.count for x in i), f"bad GEP on {self.dtype}, {i}"
|
||||
return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
@staticmethod
|
||||
def load(*src:UOp, dtype:DType): return UOp(UOps.LOAD, dtype, src)
|
||||
|
||||
@@ -71,9 +71,7 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
|
||||
replaces[base] = ret = base.replace(src=tuple(_replace_uop(x, replaces) for x in base.src))
|
||||
return ret
|
||||
@functools.lru_cache(None)
|
||||
def _prg(k:Optional[Kernel]) -> Optional[str]:
|
||||
try: return k.to_program().src if isinstance(k, Kernel) else None
|
||||
except Exception: return None
|
||||
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
|
||||
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
|
||||
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=pcall(_prg, k))
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
|
||||
Reference in New Issue
Block a user