mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
better
This commit is contained in:
@@ -11,11 +11,14 @@ from tinygrad.helpers import DEBUG
|
||||
from tinygrad.dtype import INVERSE_DTYPES_DICT
|
||||
_QDTYPES: dict[str, DType] = {
|
||||
'f64': dtypes.float64, 'f32': dtypes.float32, 'f16': dtypes.float16, 'bf16': dtypes.bfloat16,
|
||||
'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None), 'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None),
|
||||
'fp4': DType.new(4, 4, "fp4", None), 'i4': DType.new(5, 4, "i4", None),
|
||||
'u64': dtypes.uint64, 'u32': dtypes.uint32, 'u16': dtypes.uint16, 'u8': dtypes.uint8,
|
||||
'i64': dtypes.int64, 'i32': dtypes.int32, 'i16': dtypes.int16, 'i8': dtypes.int8,
|
||||
'b1201': DType.new(6, 1201, "b1201", None), 'b128': DType.new(6, 128, "b128", None),
|
||||
'b65': DType.new(6, 65, "b65", None), 'b64': dtypes.uint64, 'b32': dtypes.uint32, 'b16': dtypes.uint16, 'b8': dtypes.uint8,
|
||||
'u1201': DType.new(6, 1201, "u1201", None), 'u65': DType.new(6, 65, "u65", None), 'u24': DType.new(6, 24, "u24", None),
|
||||
'b1201': DType.new(6, 1201, "b1201", None), 'b1024': DType.new(6, 1024, "b1024", None), 'b512': DType.new(6, 512, "b512", None),
|
||||
'b192': DType.new(6, 192, "b192", None), 'b128': DType.new(6, 128, "b128", None),
|
||||
'b65': DType.new(6, 65, "b65", None), 'b64': dtypes.uint64, 'b32': dtypes.uint32, 'b23': DType.new(6, 23, "b23", None), 'b16': dtypes.uint16, 'b8': dtypes.uint8, 'b4': DType.new(6, 4, "b4", None),
|
||||
'u1201': DType.new(6, 1201, "u1201", None), 'u65': DType.new(6, 65, "u65", None), 'u24': DType.new(6, 24, "u24", None), 'u23': DType.new(6, 23, "u23", None),
|
||||
'u6': DType.new(6, 6, "u6", None), 'u4': DType.new(6, 4, "u4", None),
|
||||
'u3': DType.new(6, 3, "u3", None), 'u1': DType.new(6, 1, "u1", None),
|
||||
'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None), 'i1': DType.new(5, 1, "i1", None),
|
||||
@@ -208,9 +211,25 @@ def stmt(line: str) -> Stmt|None:
|
||||
lhs = expr(l)
|
||||
return Assign(lhs, UOp(uop, dtypes.void, (lhs, expr(r))))
|
||||
if '=' in line and not any(line[:k] == p for k, p in [(3,'if '),(6,'elsif '),(4,'for ')]):
|
||||
eq = line.index('=')
|
||||
if eq > 0 and line[eq-1] not in '!<>=' and eq < len(line)-1 and line[eq+1] != '=':
|
||||
return Assign(expr(line[:eq]), expr(line[eq+1:]))
|
||||
# Find leftmost assignment = (not ==, <=, >=, !=) for chained assignment support
|
||||
eq = -1
|
||||
for i in range(1, len(line) - 1):
|
||||
if line[i] == '=' and line[i-1] not in '!<>=' and line[i+1] != '=':
|
||||
eq = i
|
||||
break
|
||||
if eq > 0:
|
||||
rhs = line[eq+1:].strip()
|
||||
# Check if RHS contains another assignment = (not ==, <=, >=, !=)
|
||||
has_assign = False
|
||||
for i in range(1, len(rhs) - 1):
|
||||
if rhs[i] == '=' and rhs[i-1] not in '!<>=' and rhs[i+1] != '=':
|
||||
has_assign = True
|
||||
break
|
||||
if has_assign:
|
||||
rhs_parsed = stmt(rhs)
|
||||
if isinstance(rhs_parsed, Assign):
|
||||
return Assign(expr(line[:eq]), rhs_parsed)
|
||||
return Assign(expr(line[:eq]), expr(rhs))
|
||||
# Bare function call (e.g., nop())
|
||||
if re.match(r'\w+\([^)]*\)$', line):
|
||||
return expr(line)
|
||||
|
||||
@@ -57,9 +57,11 @@ def _pr(n, d=0):
|
||||
case Assign(l, r):
|
||||
compound = {Ops.ADD: '+=', Ops.SUB: '-=', Ops.OR: '|=', Ops.AND: '&=', Ops.XOR: '^=', Ops.SHL: '<<=', Ops.SHR: '>>='}
|
||||
is_pc = l.op == Ops.DEFINE_VAR and l.arg[0] == 'PC'
|
||||
if r.op in compound and len(r.src) == 2 and r.src[0] == l and not is_pc:
|
||||
if isinstance(r, UOp) and r.op in compound and len(r.src) == 2 and r.src[0] == l and not is_pc:
|
||||
return f"{p}{_pr(l)} {compound[r.op]} {_pr(r.src[1])}"
|
||||
return f"{p}{_pr(l)} = {_pr(r)}"
|
||||
# Chained assignment: render without prefix for RHS
|
||||
rhs = _pr(r) if isinstance(r, Assign) else _pr(r)
|
||||
return f"{p}{_pr(l)} = {rhs}"
|
||||
case Declare(name, dt):
|
||||
base = dt.scalar() if dt.count > 1 else dt
|
||||
suffix = f"[{dt.count}]" if dt.count > 1 else ""
|
||||
@@ -104,7 +106,8 @@ def _norm(s, keep_structure=False):
|
||||
return s.strip()
|
||||
|
||||
def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
|
||||
ok, fail, match, errs = 0, 0, 0, {}
|
||||
ok, fail, match = 0, 0, 0
|
||||
errs: dict[str, list[str]] = {}
|
||||
for cls, ops in pcode_strings.items():
|
||||
for op, pc in ops.items():
|
||||
try:
|
||||
@@ -112,7 +115,9 @@ def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
|
||||
ok += 1
|
||||
except Exception as e:
|
||||
fail += 1
|
||||
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
|
||||
key = str(e)[:60]
|
||||
if key not in errs: errs[key] = []
|
||||
errs[key].append(f"{cls.__name__}.{op.name}")
|
||||
continue
|
||||
rendered = _pr(ast)
|
||||
if _norm(pc) == _norm(rendered):
|
||||
@@ -136,14 +141,15 @@ def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
|
||||
roundtrip_rate = 100 * match / ok if ok > 0 else 0
|
||||
if DEBUG:
|
||||
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
|
||||
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
||||
for e, ops in sorted(errs.items(), key=lambda x: -len(x[1])):
|
||||
print(f" {len(ops)}: {e} ({ops[0]})")
|
||||
test.assertGreater(parse_rate, min_parse, f"Parse rate {parse_rate:.1f}% should be >{min_parse}%")
|
||||
test.assertGreater(roundtrip_rate, min_roundtrip, f"Roundtrip rate {roundtrip_rate:.1f}% should be >{min_roundtrip}%")
|
||||
|
||||
class TestQcodeParseAndRoundtrip(unittest.TestCase):
|
||||
def test_rdna3(self): _test_arch(self, RDNA3_PCODE)
|
||||
def test_rdna4(self): _test_arch(self, RDNA4_PCODE, min_parse=96)
|
||||
def test_cdna(self): _test_arch(self, CDNA_PCODE, min_parse=78)
|
||||
def test_cdna(self): _test_arch(self, CDNA_PCODE, min_parse=95)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user