Mirror of https://github.com/tinygrad/tinygrad.git
Merge branch 'master' into retinanet_mlperf
@@ -273,7 +273,7 @@ def train_resnet():
else:
it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
i, proc = 0, data_get(it)

prev_cookies = []
while proc is not None:
GlobalCounters.reset()
@@ -692,7 +692,7 @@ def train_unet3d():
loss.backward()
optim.step()
return loss.realize()

@Tensor.train(mode=False)
@Tensor.test()
def eval_step(model, x, y):
@@ -701,7 +701,7 @@ def train_unet3d():
loss = dice_ce_loss(y_hat, y)
score = dice_score(y_hat, y)
return loss.realize(), score.realize()

if WANDB: wandb.init(config=config, project=PROJ_NAME)

step_times, start_epoch = [], 1
@@ -710,7 +710,7 @@ def train_unet3d():
next_eval_at = start_eval_at

print(f"Training on {GPUS}")

if BENCHMARK: print("Benchmarking UNet3D")
else: print(f"Start evaluation at epoch {start_eval_at} and every {evaluate_every} epoch(s) afterwards")

@@ -821,7 +821,8 @@ def train_rnnt():
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
t.shard_(GPUS, axis=0)
if len(GPUS) > 1: t.shard_(GPUS, axis=0)
else: t.to_(GPUS[0])
optimizer.zero_grad()

lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
@@ -829,7 +830,7 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
(loss * loss_scaler).backward()

global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
for p in optimizer.params:
for p in optimizer.params:
p.grad = p.grad / loss_scaler
global_norm += p.grad.float().square().sum()
global_norm = global_norm.sqrt()
@@ -843,7 +844,8 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
t.shard_(GPUS, axis=0)
if len(GPUS) > 1: t.shard_(GPUS, axis=0)
else: t.to_(GPUS[0])
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
@@ -942,8 +944,9 @@ def train_bert():
p = p.assign(Tensor.zeros_like(p).contiguous()).realize()

parameters = get_parameters(model)
for p in parameters:
p.to_(GPUS)
if len(GPUS) > 1:
for p in parameters:
p.to_(GPUS)

# ** Log run config **
for key, value in config.items(): print(f'HParam: "{key}": {value}')
@@ -1061,7 +1064,7 @@ def train_bert():
if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
if MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
if getenv("RESET_STEP", 0) or INITMLPERF: train_step_bert.reset()
if getenv("RESET_STEP", 0): train_step_bert.reset()
else: train_step_bert.captured.free_intermediates()
eval_lm_losses = []
eval_clsf_losses = []
@@ -8,6 +8,7 @@ export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

export RESET_STEP=1
export BENCHMARK=10 DEBUG=2

python3 examples/mlperf/model_train.py

@@ -17,7 +17,7 @@ DATETIME=$(date "+%m%d%H%M")
LOGFILE="bert_green_${DATETIME}_${SEED}.log"

# init
BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 python3 examples/mlperf/model_train.py | tee $LOGFILE

# run
PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE
@@ -8,6 +8,7 @@ export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

export RESET_STEP=1
export BENCHMARK=10 DEBUG=2

python3 examples/mlperf/model_train.py

@@ -17,7 +17,7 @@ DATETIME=$(date "+%m%d%H%M")
LOGFILE="bert_red_${DATETIME}_${SEED}.log"

# init
BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 python3 examples/mlperf/model_train.py | tee $LOGFILE

# run
PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE
@@ -59,7 +59,7 @@ class NVDriver(VirtDriver):
self.root_handle = None

self.gpus = {}
self.next_fd = (1 << 30)
self.next_fd = (1 << 29)
self.next_handle = 1

self.object_by_handle = {}
@@ -499,5 +499,17 @@ class TestHCQ(unittest.TestCase):
assert "0xDEADBEE1" in str(ctx.exception)
os.environ.pop("MOCKGPU_EMU_FAULTADDR")

def test_multidevice(self):
try: amd_dev = Device["AMD"]
except Exception: self.skipTest("no AMD device, test skipped")

try: nv_dev = Device["NV"]
except Exception: self.skipTest("no NV device, test skipped")

x = amd_dev.signal_t()
y = nv_dev.signal_t()
assert type(x) is amd_dev.signal_t
assert type(y) is nv_dev.signal_t

if __name__ == "__main__":
unittest.main()
@@ -72,6 +72,13 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})

class TestSchedule(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
def test_error_on_device_mismatch(self):
a = Tensor.empty(10)
b = Tensor.empty(10, device="CPU")
c = a+b
with self.assertRaises(RuntimeError): check_schedule(c, 1)

def test_basic_binop_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@@ -2,7 +2,7 @@ import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import GroupOp, UOp, Ops, exec_alu
from tinygrad.codegen.devectorizer import full_graph_rewrite, mulacc_unrolled
from tinygrad.codegen.devectorizer import full_graph_rewrite

# Helper function to apply the graph rewrite
def apply_rewrite(expr):
@@ -275,41 +275,6 @@ class TestSubstitute(unittest.TestCase):
ret = substitute(ret, {a.sin():a.sqrt(), n1.sin():n1.sqrt()})
self.assertIs(ret, a.sqrt().sqrt())

class TestMulaccUnrolledAcc(unittest.TestCase):
def test_unrolled2(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))
acc = UOp(Ops.DEFINE_ACC, dtypes.int, (UOp.const(dtypes.int, 0),) + acc_range, (0,))
a = UOp.variable('a', 0, 10)
b = UOp.variable('b', 0, 10)
expr = acc.assign(acc + (a*2 + b*3))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)
self.assertIs(expr_with_mulacc, acc.assign(acc + a*2 + b*3))

def test_unrolled4_float(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3))
acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,))

a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]
b = [UOp.variable(f'b{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]

expr = acc.assign(acc + (a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)

# Verify it unrolls into individual multiply-accumulate operations
expected = acc.assign(acc + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3])
self.assertIs(expr_with_mulacc, expected)

def test_unrolled4_float_const(self):
acc_range = (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 3))
acc = UOp(Ops.DEFINE_ACC, dtypes.float32, (UOp.const(dtypes.int, 0),)+acc_range, (0,))

a = [UOp.variable(f'a{i}', float("-inf"), float("inf"), dtype=dtypes.float32) for i in range(4)]
expr = acc.assign(acc + (a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0))
expr_with_mulacc = graph_rewrite(expr, mulacc_unrolled)

# Verify it unrolls into individual multiply-accumulate operations
expected = acc.assign(acc + a[0]*3.0 + a[1]*4.0 + a[2]*5.0 + a[3]*6.0)
self.assertIs(expr_with_mulacc, expected)

if __name__ == '__main__':
unittest.main()
@@ -4,7 +4,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
from tinygrad.ops import graph_rewrite, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, mulacc_unrolled
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym
from tinygrad.helpers import getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@@ -232,12 +232,11 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])

if DEVECTORIZE:
# devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
mulacc_unrolled)
# devectorize + load_store_indexing
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing)
else:
# new devectorize only for load/store
sink = graph_rewrite(sink, sym+devectorize_load_store+mulacc_unrolled)
sink = graph_rewrite(sink, sym+devectorize_load_store)

# optional pre matcher
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
@@ -115,14 +115,17 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
alu_op: Ops = x.arg[0]
ret = x.src[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
if not len(reduce_range): return ret
# create ACC and assign
# create acc
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
ctx.acc_num += 1
return acc.assign(acc.alu(alu_op, ret))
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+[ret.gep(i) for i in range(ret.dtype.count)])
else:
ret = acc.alu(alu_op, ret)
if not len(reduce_range): return ret
# create ACC and assign
return acc.assign(ret)

def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
@@ -318,9 +318,10 @@ def threefry2x32(x: UOp, key: UOp):

# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********

def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extra=None,vec=None,ne=None,
def loop_collapse(compval, multconst, rng:UOp, acc:UOp, extra:UOp, idx2=None,idx3=None,vec=None,ne=None,
add=UOp.const(dtypes.int, 0), mul:UOp=UOp.const(dtypes.int, 1)):
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in acc.src: return None # must be the right REDUCE
if acc not in split_uop(extra, Ops.ADD): return None
loop_start, loop_end = rng.src
if loop_start.arg != 0:
# TODO: support and test this with other mul and loop_starts
@@ -344,7 +345,7 @@ def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extr
# TODO: what does it mean to have the same numbered DEFINE_ACC with different ranges?
new_acc = acc.replace(src=acc.src[0:1]+tuple(x for x in acc.src[1:] if x is not rng))
ret = new_acc.assign(new_acc+new_reduce_op)
if extra is not None: ret = ret + acc.assign(acc+extra)
if extra is not acc: ret = ret + acc.assign(extra)
return ret

def index_collapse(idx:UOp,rng:UOp,buf:UOp,ld:UOp,acc:UOp,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
@@ -383,9 +384,6 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld")
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))

# this moves the accumulation variable down an unrolled add chain which allows for more efficient accumulation using mulacc
mulacc_unrolled = PatternMatcher([(UPat.var("x")+UPat.var("y")+acc_pat, lambda x,y,acc: (acc+x)+y if y.op is not Ops.DEFINE_ACC else None)])

# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
@@ -435,7 +433,7 @@ sym = symbolic_flat+PatternMatcher([
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
# arange loop folding
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
(acc_pat.assign(arange_m+UPat.var("extra")), loop_collapse),
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
@@ -54,6 +54,11 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))

def create_buffer_view(tr:UOp, x:UOp):
assert isinstance(tr.device, str), "device must be string"
if not tr.device.startswith("DISK"): return None
return UOp(Ops.BUFFER_VIEW, tr.dtype, (x.base,), (tr.size, unwrap(x.st).views[0].offset)).reshape(tr.shape)

sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
@@ -93,8 +98,7 @@ sym = symbolic_simple+PatternMatcher([
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="tr"), create_buffer_view),
# put UnaryOps before EXPANDs
(UPat(GroupOp.Unary, src=UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"), name="alu"),
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
@@ -252,6 +256,9 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
create_kernels = merge_views+PatternMatcher([
(UPat(GroupOp.All-{Ops.KERNEL, Ops.BUFFER}, name="x"), create_kernel),
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# remove CONST/BIND from the kernel graph
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
])

# **** fix kernel AST
@@ -377,12 +384,13 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
# add buffer ops
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
# fix_kernel_ops
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
# create subbuffer
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)

PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
@@ -289,7 +289,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.MULTI:
return ShapeTracker.from_shape(
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
# these ops define a ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
@@ -298,7 +298,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
if self.op is Ops.BITCAST:
shape = src_sts[0].shape
if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
# only reduce ops are allowed to change shape, everything else derives shape from sources
@@ -316,7 +316,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
@property
def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
def size(self) -> int: return self.arg[0] if self.op is Ops.BUFFER_VIEW else self.arg if self.op is Ops.BUFFER else unwrap(self.st).size

# *** uop evaluation ***
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, cast
from typing import Any, cast, ClassVar
import os, ctypes, ctypes.util, functools, mmap, errno, array, contextlib, sys, select
assert sys.platform != 'win32'
from dataclasses import dataclass
@@ -27,10 +27,7 @@ def nbioreg(reg): return reg + 0x00000d20 # NBIO_BASE__INST0_SEG2

class AMDSignal(HCQSignal):
def __init__(self, base_addr:int|None=None, **kwargs):
super().__init__(AMDDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=100)

def __del__(self):
if isinstance(self.base_addr, int): AMDDevice.signals_pool.append(self.base_addr)
super().__init__(base_addr, **kwargs, timestamp_divider=100, dev_t=AMDDevice)

def _sleep(self, time_spent_waiting_ms:int):
# Resonable to sleep for long workloads (which take more than 2s) and only timeline signals.
@@ -562,9 +559,11 @@ class PCIIface:
def device_fini(self): self.adev.fini()

class AMDDevice(HCQCompiled):
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[Any]] = []
signal_pool: ClassVar[list[int]] = []

driverless:bool = not HWInterface.exists('/sys/module/amdgpu') or bool(getenv("AMD_DRIVERLESS", 0))
signals_page:Any = None
signals_pool:list[int] = []

def __init__(self, device:str=""):
self.device_id = int(device.split(":")[1]) if ":" in device else 0
@@ -573,11 +572,6 @@ class AMDDevice(HCQCompiled):
self.arch = "gfx%d%x%x" % (self.target // 10000, (self.target // 100) % 100, self.target % 100)
if self.target < 100300 or self.target >= 120000: raise RuntimeError(f"Unsupported arch: {self.arch}")

if AMDDevice.signals_page is None:
AMDDevice.signals_page = self.dev_iface.alloc(16 * 65536, host=True, uncached=True, cpu_access=True)
AMDDevice.signals_pool = [AMDDevice.signals_page.va_addr + off for off in range(0, AMDDevice.signals_page.size, 16)]
else: self.dev_iface.map(AMDDevice.signals_page)

self.max_cu_id = self.dev_iface.props['simd_count'] // self.dev_iface.props['simd_per_cu'] - 1
self.max_wave_id = self.dev_iface.props['max_waves_per_simd'] * self.dev_iface.props['simd_per_cu'] - 1
self.has_scratch_base_registers = self.target >= 110000
@@ -169,7 +169,8 @@ class DSPDevice(Compiled):
except (OSError, PermissionError):
# DSP might ask for a connection reset or just fail with operation not permitted, try to reset connection.
self.init_dsp()
_exec_lib()
try: _exec_lib()
except (OSError, PermissionError) as e: raise RuntimeError(e)

def init_dsp(self):
if hasattr(self, 'rpc_fd'):
@@ -1,7 +1,7 @@
from __future__ import annotations
import os, ctypes, contextlib, re, functools, mmap, struct, array, sys
assert sys.platform != 'win32'
from typing import Any, cast, Union, Type
from typing import Any, cast, Union, Type, ClassVar
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator
from tinygrad.runtime.support.hcq import HWInterface, MOCKGPU
@@ -73,10 +73,7 @@ assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4

class NVSignal(HCQSignal):
def __init__(self, base_addr:int|None=None, **kwargs):
super().__init__(NVDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=1000, value_off=0, timestamp_off=8)

def __del__(self):
if isinstance(self.base_addr, int): NVDevice.signals_pool.append(self.base_addr)
super().__init__(base_addr, **kwargs, timestamp_divider=1000, dev_t=NVDevice)

class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def __init__(self):
@@ -285,12 +282,14 @@ class GPFifo:

MAP_FIXED, MAP_NORESERVE = 0x10, 0x400
class NVDevice(HCQCompiled[NVSignal]):
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[Any]] = []
signal_pool: ClassVar[list[int]] = []

root = None
fd_ctl: HWInterface
fd_uvm: HWInterface
gpus_info: Union[list, ctypes.Array] = []
signals_page: Any = None
signals_pool: list[int] = []

# TODO: Need a proper allocator for va addresses
# 0x1000000000 - 0x2000000000, reserved for system/cpu mappings
@@ -433,11 +432,6 @@ class NVDevice(HCQCompiled[NVSignal]):
try: uvm.enable_peer_access(self.fd_uvm, gpuUuidA=self.gpu_uuid, gpuUuidB=dev.gpu_uuid)
except RuntimeError as e: raise RuntimeError(str(e) + f". Make sure GPUs #{self.gpu_minor} & #{dev.gpu_minor} have P2P enabled between.") from e

if NVDevice.signals_page is None:
NVDevice.signals_page = self._gpu_alloc(16 * 65536, cpu_access=True, uncached=True)
NVDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, NVDevice.signals_page.size, 16)]
else: self._gpu_map(NVDevice.signals_page)

channel_params = nv_gpu.NV_CHANNEL_GROUP_ALLOCATION_PARAMETERS(engineType=nv_gpu.NV2080_ENGINE_TYPE_GRAPHICS)
channel_group = rm_alloc(self.fd_ctl, nv_gpu.KEPLER_CHANNEL_GROUP_A, self.root, self.nvdevice, channel_params).hObjectNew
@@ -2,7 +2,7 @@ from __future__ import annotations
import os, ctypes, functools, mmap, struct, array, math, sys
assert sys.platform != 'win32'
from types import SimpleNamespace
from typing import Any, cast
from typing import Any, cast, ClassVar
from tinygrad.device import BufferSpec
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
from tinygrad.runtime.support.hcq import HWInterface
@@ -38,10 +38,7 @@ class QCOMCompiler(CLCompiler):

class QCOMSignal(HCQSignal):
def __init__(self, base_addr:int|None=None, **kwargs):
super().__init__(QCOMDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=19.2)

def __del__(self):
if isinstance(self.base_addr, int): QCOMDevice.signals_pool.append(self.base_addr)
super().__init__(base_addr, **kwargs, timestamp_divider=19.2, dev_t=QCOMDevice)

def _sleep(self, time_spent_waiting_ms:int):
# Sleep only for only timeline signals. Do it immediately to free cpu.
@@ -320,16 +317,16 @@ class QCOMAllocator(HCQAllocatorBase):
self.dev._gpu_free(opaque)

class QCOMDevice(HCQCompiled):
signals_page: Any = None
signals_pool: list[int] = []
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[Any]] = []
signal_pool: ClassVar[list[int]] = []

gpu_id: int = 0
dummy_addr: int = 0

def __init__(self, device:str=""):
self.fd = HWInterface('/dev/kgsl-3d0', os.O_RDWR)
QCOMDevice.dummy_addr = cast(int, self._gpu_alloc(0x1000).va_addr)
QCOMDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
QCOMDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, self.signals_page.size, 16)]

flags = kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT | kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC \
| kgsl.KGSL_CONTEXT_PRIORITY(8) | kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import cast, Type, TypeVar, Generic, Any
from typing import cast, Type, TypeVar, Generic, Any, ClassVar
import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
from tinygrad.renderer import Renderer
@@ -203,15 +203,20 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")

class HCQSignal(Generic[DeviceType]):
def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:DeviceType|None=None, timestamp_divider=1, value_off=0, timestamp_off=8):
self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
def __init__(self, base_addr:sint|None=None, value:int=0, dev_t:Type[DeviceType]|None=None, timeline_for_device:DeviceType|None=None,
timestamp_divider=1, value_off=0, timestamp_off=8):
self.base_addr = dev_t._alloc_signal_addr() if dev_t is not None and base_addr is None else base_addr
self.value_addr, self.timestamp_addr, self.dev_t = self.base_addr+value_off, self.base_addr+timestamp_off, dev_t
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
self.timeline_for_device:DeviceType|None = timeline_for_device

if isinstance(base_addr, int):
if isinstance(self.base_addr, int):
self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
self.value_mv[0] = value

def __del__(self):
if isinstance(self.base_addr, int) and self.dev_t is not None: self.dev_t.signal_pool.append(self.base_addr)

@property
def value(self) -> int: return self.value_mv[0]

@@ -332,23 +337,29 @@ class HCQCompiled(Compiled, Generic[SignalType]):
"""
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
"""
devices: list[HCQCompiled] = []
devices: ClassVar[list[HCQCompiled]] = []
signal_pages: ClassVar[list[Any]] = []
signal_pool: ClassVar[list[int]] = []

def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
comp_queue_t:Type[HWQueue], copy_queue_t:Type[HWQueue]|None):
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0

from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)

# Map signals if any
for sig_page in self.signal_pages: cast(HCQAllocator, self.allocator).map(sig_page)
self.devices.append(self)

self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
self.timeline_value:int = 1
self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []

from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)

self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, base=cast(int, self.kernargs_page.va_addr), wrap=True)
self.devices.append(self)

def synchronize(self):
try: self.timeline_signal.wait(self.timeline_value - 1)
@@ -361,6 +372,14 @@ class HCQCompiled(Compiled, Generic[SignalType]):
Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
self.sig_prof_records = []

@classmethod
def _alloc_signal_addr(cls) -> int:
if not cls.signal_pool:
cls.signal_pages.append(alc:=cls.devices[0].allocator.alloc(0x1000, BufferSpec(host=True, uncached=True, cpu_access=True)))
cls.signal_pool += [alc.va_addr + off for off in range(0, alc.size, 16)]
for dev in cls.devices: cast(HCQAllocator, dev.allocator).map(alc)
return cls.signal_pool.pop()

def _at_profile_finalize(self):
def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
@@ -1,13 +1,15 @@
from typing import cast
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
from tinygrad.helpers import all_same, dedup, prod
from tinygrad.helpers import all_same, all_int, dedup, prod

buffer_spec = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
])

# *** this is the spec of a Tensor in UOp ***
@@ -126,9 +128,7 @@ kernel_spec = buffer_spec+PatternMatcher([
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
# assign has a buffer view and kernel source, it can optionally depend on other assigns
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
# view/sink/const/bind/var can also exist in the kernel graph
(UPat((Ops.VIEW, Ops.SINK, Ops.CONST, Ops.BIND, Ops.DEFINE_VAR)), lambda: True),
(UPat(GroupOp.All), lambda: False),
(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
])

# *** this is the UOp shape spec ***
@@ -13,7 +13,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.NAME:"#808080"}

# VIZ API