mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
speedups
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
import struct, math, re
|
||||
from enum import IntEnum
|
||||
from functools import cache, cached_property
|
||||
from functools import cache
|
||||
from typing import overload, Annotated, TypeVar, Generic
|
||||
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
|
||||
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
|
||||
@@ -346,6 +346,7 @@ class Inst:
|
||||
if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
|
||||
orig_args = dict(zip(field_names, args)) | kwargs
|
||||
self._values.update(orig_args)
|
||||
self._precompute()
|
||||
self._validate(orig_args)
|
||||
# Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user)
|
||||
if literal is not None:
|
||||
@@ -386,6 +387,7 @@ class Inst:
|
||||
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
|
||||
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
|
||||
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
|
||||
self._precompute_fields()
|
||||
|
||||
def _encode_field(self, name: str, val) -> int:
|
||||
if isinstance(val, RawImm): return val.val
|
||||
@@ -450,6 +452,8 @@ class Inst:
|
||||
inst = object.__new__(cls)
|
||||
inst._values = {n: RawImm(v) if n in SRC_FIELDS else v for n, bf in cls._fields.items() if n != 'encoding' for v in [(word >> bf.lo) & bf.mask()]}
|
||||
inst._literal = None
|
||||
inst._precompute()
|
||||
inst._precompute_fields()
|
||||
return inst
|
||||
|
||||
@classmethod
|
||||
@@ -510,25 +514,32 @@ class Inst:
|
||||
'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
|
||||
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
"""Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
|
||||
def _precompute(self):
|
||||
"""Precompute op, op_name, _spec_regs, _spec_dtype for fast access."""
|
||||
val = self._values.get('op')
|
||||
if val is None: return None
|
||||
if hasattr(val, 'name'): return val # already an enum
|
||||
cls_name = self.__class__.__name__
|
||||
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
|
||||
return self._enum_map[cls_name](val)
|
||||
if val is None: self.op = None
|
||||
elif hasattr(val, 'name'): self.op = val
|
||||
else:
|
||||
cls_name = self.__class__.__name__
|
||||
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
|
||||
if cls_name == 'VOP3':
|
||||
try:
|
||||
if val < 256: self.op = VOPCOp(val)
|
||||
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
|
||||
else: self.op = VOP3Op(val)
|
||||
except ValueError: self.op = val
|
||||
elif cls_name in self._enum_map:
|
||||
try: self.op = self._enum_map[cls_name](val)
|
||||
except ValueError: self.op = val
|
||||
else: self.op = val
|
||||
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
|
||||
self._spec_regs = spec_regs(self.op_name)
|
||||
self._spec_dtype = spec_dtype(self.op_name)
|
||||
|
||||
@cached_property
|
||||
def op_name(self) -> str:
|
||||
op = self.op
|
||||
return op.name if hasattr(op, 'name') else ''
|
||||
|
||||
@cached_property
|
||||
def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name)
|
||||
@cached_property
|
||||
def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name)
|
||||
def _precompute_fields(self):
|
||||
"""Unwrap all field values as direct attributes for fast access."""
|
||||
for name, val in self._values.items():
|
||||
if name != 'op': setattr(self, name, unwrap(val))
|
||||
def dst_regs(self) -> int: return self._spec_regs[0]
|
||||
def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
|
||||
def num_srcs(self) -> int: return spec_num_srcs(self.op_name)
|
||||
|
||||
@@ -179,9 +179,8 @@ def decode_program(data: bytes) -> Program:
|
||||
base_size = inst_class._size()
|
||||
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
|
||||
inst = inst_class.from_bytes(data[i:i+base_size+8])
|
||||
for name, val in inst._values.items():
|
||||
if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access
|
||||
inst._words = inst.size() // 4
|
||||
inst._fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
|
||||
result[i // 4] = inst
|
||||
i += inst._words * 4
|
||||
return result
|
||||
@@ -201,10 +200,10 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
elif isinstance(inst, SOPP): ssrc0, sdst = None, None
|
||||
else: raise NotImplementedError(f"Unknown scalar type {type(inst)}")
|
||||
|
||||
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
|
||||
fn = inst._fn
|
||||
if fn is None:
|
||||
if isinstance(inst, SOPP): return 0 # SOPP without pseudocode (waits, hints, nops) are no-ops
|
||||
raise NotImplementedError(f"{inst.op.name} not in pseudocode")
|
||||
raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
|
||||
# SMEM: memory loads
|
||||
if isinstance(inst, SMEM):
|
||||
@@ -250,26 +249,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
|
||||
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = _exec_vopd(vopd_op, s0, s1, d0, st, lane)
|
||||
return
|
||||
|
||||
# Lookup compiled function for this op (V_NOP has no pcode)
|
||||
if isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: return
|
||||
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
|
||||
fn = inst._fn
|
||||
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
|
||||
# Memory ops (FLAT/GLOBAL/SCRATCH and DS)
|
||||
if isinstance(inst, (FLAT, DS)):
|
||||
ndwords = _op_ndwords(inst.op.name)
|
||||
ndwords = _op_ndwords(inst.op_name)
|
||||
if isinstance(inst, FLAT):
|
||||
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
|
||||
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
|
||||
vdata_src = inst.vdst if 'LOAD' in inst.op.name else inst.data
|
||||
vdata_src = inst.vdst if 'LOAD' in inst.op_name else inst.data
|
||||
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), V[inst.vdst])
|
||||
if 'VDATA' in result: _vgpr_write(V, inst.vdst, result['VDATA'], ndwords)
|
||||
if 'RETURN_DATA' in result: _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords)
|
||||
else: # DS
|
||||
data0, data1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else 0
|
||||
result = fn(lds, V[inst.addr], data0, data1, inst.offset0, inst.offset1)
|
||||
if 'RETURN_DATA' in result and ('_RTN' in inst.op.name or '_LOAD' in inst.op.name):
|
||||
_vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op.name else ndwords)
|
||||
if 'RETURN_DATA' in result and ('_RTN' in inst.op_name or '_LOAD' in inst.op_name):
|
||||
_vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op_name else ndwords)
|
||||
return
|
||||
|
||||
# VOP3SD: has extra scalar dest for carry output
|
||||
@@ -389,7 +386,8 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int
|
||||
st.pc += inst_words + exec_scalar(st, inst)
|
||||
return 0
|
||||
# Wave-level vector ops: execute once for entire wave (not per-lane)
|
||||
if isinstance(inst, VOP3P) and 'WMMA' in inst.op_name:
|
||||
if isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: pass
|
||||
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name:
|
||||
exec_wmma(st, inst, inst.op)
|
||||
elif isinstance(inst, VOP3) and inst.op == VOP3Op.V_WRITELANE_B32:
|
||||
wr_lane = st.rsrc(inst.src1, 0) & 0x1f
|
||||
|
||||
Reference in New Issue
Block a user