don't make dtype

This commit is contained in:
George Hotz
2026-01-04 13:49:47 -08:00
parent 058816dd92
commit 10e2c47d52
2 changed files with 22 additions and 22 deletions

View File

@@ -11,11 +11,13 @@ _QDTYPES: dict[str, DType] = {
'f64': dtypes.float64, 'f32': dtypes.float32, 'f16': dtypes.float16, 'bf16': dtypes.bfloat16,
'u64': dtypes.uint64, 'u32': dtypes.uint32, 'u16': dtypes.uint16, 'u8': dtypes.uint8,
'i64': dtypes.int64, 'i32': dtypes.int32, 'i16': dtypes.int16, 'i8': dtypes.int8,
'u24': DType.new(6, 24, "uint24", None), 'i24': DType.new(5, 24, "int24", None),
'u4': DType.new(6, 4, "uint4", None), 'i4': DType.new(5, 4, "int4", None),
'u3': DType.new(6, 3, "uint3", None), 'u1': DType.new(6, 1, "uint1", None), 'i1': DType.new(5, 1, "int1", None),
'b128': DType.new(8, 128, "bits128", None), 'b64': dtypes.uint64, 'b32': dtypes.uint32,
'b16': dtypes.uint16, 'b8': dtypes.uint8,
'b128': DType.new(8, 128, "bits128", None), 'b65': DType.new(8, 65, "bits65", None),
'b64': dtypes.uint64, 'b32': dtypes.uint32, 'b16': dtypes.uint16, 'b8': dtypes.uint8,
'b1201': DType.new(8, 1201, "bits1201", None),
'u65': DType.new(6, 65, "uint65", None), 'u24': DType.new(6, 24, "uint24", None),
'u6': DType.new(6, 6, "uint6", None), 'u4': DType.new(6, 4, "uint4", None),
'u3': DType.new(6, 3, "uint3", None), 'u1': DType.new(6, 1, "uint1", None),
'i65': DType.new(5, 65, "int65", None), 'i24': DType.new(5, 24, "int24", None), 'i1': DType.new(5, 1, "int1", None),
'u': dtypes.uint32, 'i': dtypes.int32, 'f': dtypes.float32,
}
@@ -33,7 +35,7 @@ _UNOPS: dict[str, Ops] = {'-': Ops.NEG, '~': Ops.XOR, '!': Ops.CMPEQ}
@dataclass(frozen=True)
class Assign: lhs: UOp; rhs: UOp
@dataclass(frozen=True)
class Declare: name: str; dtype: str
class Declare: name: str; dtype: DType
@dataclass(frozen=True)
class If: branches: tuple[tuple[UOp|None, tuple[Stmt, ...]], ...]
@dataclass(frozen=True)
@@ -76,14 +78,6 @@ def _fop(s, ops):
def _get_dtype(name: str) -> DType | None: return _QDTYPES.get(name.lower())
def _make_dtype(bits: int, typ: str) -> DType:
key = f"{typ.lower()}{bits}"
if key in _QDTYPES: return _QDTYPES[key]
if typ == 'F': return DType.new(13, bits, f"float{bits}", None)
if typ == 'I': return DType.new(5, bits, f"int{bits}", None)
if typ == 'U' or typ == 'B': return DType.new(6, bits, f"uint{bits}", None)
return dtypes.uint32
def expr(s: str) -> UOp:
s = s.strip().rstrip(';')
if s.endswith('.') and not (len(s) > 1 and s[-2].isdigit()): s = s[:-1]
@@ -96,14 +90,14 @@ def expr(s: str) -> UOp:
if s[0] == '{' and s[-1] == '}': return UOp(Ops.CAT, dtypes.void, tuple(expr(a) for a in _split(s[1:-1])))
# Typed cast: 32'U(expr)
if m := re.match(r"^(\d+)'([IUFB])\(", s):
if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1: return UOp(Ops.CAST, _make_dtype(int(m[1]), m[2]), (expr(s[m.end():e]),))
if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1: return UOp(Ops.CAST, _QDTYPES[f"{m[2].lower()}{m[1]}"], (expr(s[m.end():e]),))
# Typed constant: 32'-5I
if m := re.match(r"^(\d+)'(-?\d+)([IUFB])?$", s):
return UOp(Ops.CONST, _make_dtype(int(m[1]), m[3] or 'I'), arg=int(m[2]))
return UOp(Ops.CONST, _QDTYPES[f"{(m[3] or 'I').lower()}{m[1]}"], arg=int(m[2]))
if m := re.match(r"^(\d+)'(-?[\d.]+)$", s):
return UOp(Ops.CONST, _make_dtype(int(m[1]), 'F'), arg=float(m[2]))
return UOp(Ops.CONST, _QDTYPES[f"f{m[1]}"], arg=float(m[2]))
if m := re.match(r"^(\d+)'(0x[0-9a-fA-F]+)$", s):
return UOp(Ops.CONST, _make_dtype(int(m[1]), 'U'), arg=int(m[2], 16))
return UOp(Ops.CONST, _QDTYPES[f"u{m[1]}"], arg=int(m[2], 16))
# Function call -> CUSTOM
if m := re.match(r"^([A-Za-z_]\w*)\(", s):
if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1:
@@ -175,7 +169,11 @@ def stmt(line: str) -> Stmt|None:
line = line.split('//')[0].strip().rstrip(';')
if not line: return None
if line[:5] == 'eval ': return Assign(UOp(Ops.DEFINE_VAR, dtypes.void, arg=('_eval', None, None)), UOp(Ops.DEFINE_VAR, dtypes.void, arg=(line, None, None)))
if line[:8] == 'declare ' and ':' in line: n, t = line[8:].split(':', 1); return Declare(n.strip(), t.strip())
if line[:8] == 'declare ' and ':' in line:
n, t = line[8:].split(':', 1)
t = t.strip().split('[')[0] # strip array suffix like [64]
if m := re.match(r"^(\d+)'([IUFB])$", t): return Declare(n.strip(), _QDTYPES[f"{m[2].lower()}{m[1]}"])
return None # unsupported declare type
for op, uop in [('+=', Ops.ADD), ('-=', Ops.SUB), ('|=', Ops.OR), ('&=', Ops.AND), ('^=', Ops.XOR), ('<<=', Ops.SHL), ('>>=', Ops.SHR)]:
if op in line:
l, r = line.split(op, 1)
@@ -238,7 +236,7 @@ if __name__ == "__main__":
case UOp(Ops.CUSTOM, _, args, name): return f"{name}({', '.join(pr(x) for x in args)})"
case UOp(Ops.CAT, _, exprs): return f"{{{', '.join(pr(x) for x in exprs)}}}"
case Assign(l, r): return f"{p}{pr(l)} = {pr(r)}"
case Declare(n, t): return f"{p}declare {n}: {t}"
case Declare(n, t): return f"{p}declare {n}: {t.name}"
case If(br): return f"{p}if " + " elif ".join(f"({pr(c) if c is not None else 'else'}) {{\n" + "\n".join(pr(s,d+1) for s in b) + f"\n{p}}}" for c,b in br)
case For(v,s,e,b): return f"{p}for {v} in {pr(s)}:{pr(e)} {{\n" + "\n".join(pr(x,d+1) for x in b) + f"\n{p}}}"
case tuple(): return "\n".join(pr(x, d) for x in n)

View File

@@ -45,6 +45,7 @@ LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg
class Ctx:
def __init__(self, mem_buf: UOp = MEM_BUF):
self.vars: dict[str, UOp] = dict(INPUT_VARS)
self.decls: dict[str, DType] = {}
self.outputs: list[tuple[str, UOp, DType]] = []
self.mem_stores: list[UOp] = []
self.mem_buf = mem_buf
@@ -354,7 +355,7 @@ def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, s
def _stmt(stmt, ctx: Ctx):
match stmt:
case Declare(_, _): pass
case Declare(name, dtype): ctx.decls[name] = dtype
case Assign(lhs, rhs):
# Handle MEM[addr].type = value -> memory store
if lhs.op == Ops.BITCAST and lhs.src[0].op == Ops.CUSTOM and lhs.src[0].arg == 'MEM':
@@ -434,8 +435,9 @@ def _transform_if(branches: tuple, ctx: Ctx):
def _transform_for(var: str, start: UOp, end: UOp, body: tuple, ctx: Ctx):
start_val = start.arg if start.op == Ops.CONST else int(_expr(start, ctx).arg)
end_val = end.arg if end.op == Ops.CONST else int(_expr(end, ctx).arg)
var_dtype = ctx.decls.get(var, dtypes.uint32)
for i in range(int(end_val), int(start_val) - 1, -1):
ctx.vars[var] = UOp.const(dtypes.uint32, i)
ctx.vars[var] = UOp.const(var_dtype, i)
for s in body:
if isinstance(s, If): _transform_if(s.branches, ctx)
elif isinstance(s, Assign): _stmt(s, ctx)