Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2024-09-26 02:20:28 -07:00
28 changed files with 209 additions and 166 deletions

View File

@@ -140,6 +140,15 @@ jobs:
python -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
pip install mypy
mypy -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
- name: Run beautiful_mnist without numpy
run: |
mkdir $HOME/test_no_numpy_dir
cd $HOME/test_no_numpy_dir
python -m venv venv
source venv/bin/activate
pip install $GITHUB_WORKSPACE
cp $GITHUB_WORKSPACE/examples/beautiful_mnist.py .
PYTHONPATH=$GITHUB_WORKSPACE BS=2 STEPS=10 python beautiful_mnist.py
- name: Test DEBUG
run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
- name: Repo line count <9800 lines

View File

@@ -8,7 +8,7 @@ from tinygrad.ops import UOps
from tinygrad.device import Compiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv, colored
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
from tinygrad.shape.symbolic import sym_infer
def get_sched_resnet():
@@ -53,7 +53,7 @@ def get_sched_bert():
# ignore grad norm and loss scaler for now
loss.backward()
targets += [x.lazydata for x in optim.schedule_step()]
sched = create_schedule(targets, seen)
sched = create_schedule(targets)
print(f"schedule length {len(sched)}")
return sched
@@ -128,7 +128,7 @@ if __name__ == "__main__":
running_gflops += gflops * tm
if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)
usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[str(m) for m in si.metadata] if si.metadata is not None else ''}")
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[repr(m) if TRACEMETA >= 2 else str(m) for m in si.metadata] if si.metadata is not None else ''}")
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
print("usage:")
for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:

View File

@@ -644,6 +644,13 @@ def train_bert():
else:
MLLOGGER = None
# ** init wandb **
WANDB = getenv("WANDB")
if WANDB:
import wandb
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
# ** hyperparameters **
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
@@ -672,7 +679,7 @@ def train_bert():
Tensor.manual_seed(seed) # seed for weight initialization
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
for _, x in get_state_dict(model).items():
x.realize().to_(GPUS)
@@ -727,14 +734,8 @@ def train_bert():
start_step = int(scheduler_wd.epoch_counter.numpy().item())
print(f"resuming from {ckpt} at step {start_step}")
# ** init wandb **
WANDB = getenv("WANDB")
if WANDB:
import wandb
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
if not INITMLPERF:
if RUNMLPERF:
# only load real data with RUNMLPERF
eval_it = iter(batch_load_val_bert(EVAL_BS))
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
for _ in range(start_step): next(train_it) # Fast forward
@@ -743,10 +744,12 @@ def train_bert():
step_times = []
# ** train loop **
wc_start = time.perf_counter()
if INITMLPERF:
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
else:
if RUNMLPERF:
# only load real data with RUNMLPERF
i, train_data = start_step, get_data_bert(GPUS, train_it)
else:
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
while train_data is not None and i < train_steps and not achieved:
Tensor.training = True
BEAM.value = TRAIN_BEAM
@@ -759,10 +762,10 @@ def train_bert():
pt = time.perf_counter()
try:
if INITMLPERF:
next_data = get_fake_data_bert(GPUS, BS)
else:
if RUNMLPERF:
next_data = get_data_bert(GPUS, train_it)
else:
next_data = get_fake_data_bert(GPUS, BS)
except StopIteration:
next_data = None
@@ -807,10 +810,10 @@ def train_bert():
BEAM.value = EVAL_BEAM
for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
if INITMLPERF:
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
else:
if RUNMLPERF:
eval_data = get_data_bert(GPUS, eval_it)
else:
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
GlobalCounters.reset()
st = time.time()

View File

@@ -2,14 +2,11 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export BASEDIR="/raid/datasets/wiki"
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
export BENCHMARK=10 DEBUG=2
python3 examples/mlperf/model_train.py

View File

@@ -2,14 +2,11 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export BASEDIR="/raid/datasets/wiki"
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
export WANDB=1
python3 examples/mlperf/model_train.py

View File

@@ -3,14 +3,11 @@
export PYTHONPATH="."
export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export BASEDIR="/raid/datasets/wiki"
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
# pip install -e ".[mlperf]"
export LOGMLPERF=1

View File

@@ -2,14 +2,11 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export BASEDIR="/raid/datasets/wiki"
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
export BENCHMARK=10 DEBUG=2
python3 examples/mlperf/model_train.py

View File

@@ -2,14 +2,11 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export BASEDIR="/raid/datasets/wiki"
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
export WANDB=1
python3 examples/mlperf/model_train.py

View File

@@ -3,14 +3,11 @@
export PYTHONPATH="."
export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_red"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
export BEAM=4
export BASEDIR="/raid/datasets/wiki"
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
# pip install -e ".[mlperf]"
export LOGMLPERF=1

View File

@@ -290,5 +290,5 @@ if __name__ == "__main__":
if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
assert distance < 45e-5, colored(f"validation failed with {distance=}", "red")
assert distance < 50e-5, colored(f"validation failed with {distance=}", "red")
print(colored(f"output validated with {distance=}", "green"))

View File

@@ -49,15 +49,15 @@ class BertForPretraining:
output = self.bert(input_ids, attention_mask, token_type_ids)
return self.cls(output, masked_lm_positions)
def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
def sparse_categorical_crossentropy(predictions:Tensor, labels:Tensor, ignore_index=-1):
log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index)
y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
# Masked cross-entropy of `predictions` against integer `labels`; `predictions.shape[-1]` is used as the number of classes.
# Positions where `labels == ignore_index` are excluded from both the summed loss and the normalizer.
# NOTE(review): both call sites visible in this class pass `masked_lm_weights` (a Tensor) as `ignore_index`,
# so the mask is an elementwise comparison of labels with that tensor — confirm this is intended.
def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1):
# loss_mask is truthy where the label takes part in the loss
log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index)
# y_counter: the class indices 0..C-1, expanded to one row per label (labels.numel() rows)
y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
# y: one-hot selection of each label's class, zeroed for masked-out labels, reshaped to (*labels.shape, C)
y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
# negative sum of the selected log-probabilities, divided by the number of unmasked labels
return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
masked_lm_loss = sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
return masked_lm_loss + next_sentence_loss
@@ -66,7 +66,7 @@ class BertForPretraining:
valid = masked_lm_ids != 0
masked_lm_predictions = prediction_logits.log_softmax().argmax(-1)
masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1)
seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)

View File

@@ -21,7 +21,7 @@ setup(name='tinygrad',
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"
],
install_requires=["numpy"],
install_requires=[],
python_requires='>=3.8',
extras_require={
'llvm': ["llvmlite"],
@@ -37,6 +37,7 @@ setup(name='tinygrad',
],
#'mlperf': ["mlperf-logging @ git+https://github.com/mlperf/logging.git@4.0.0-rc2"],
'testing': [
"numpy",
"torch",
"pillow",
"pytest",
@@ -63,7 +64,8 @@ setup(name='tinygrad',
"mkdocstrings[python]",
"markdown-callouts",
"markdown-exec[ansi]",
"black"
"black",
"numpy",
],
'testing_tf': [
"tensorflow==2.15.1",

View File

@@ -74,7 +74,9 @@ def benchmark_model(m, devices, validate_outs=False):
del inputs, tinygrad_model, tinygrad_jitted_model
except RuntimeError as e:
# TODO: we don't run the dm model on METAL for now
if Device.DEFAULT == "METAL": assert "buffer count limit" in str(e)
if Device.DEFAULT == "METAL":
assert "buffer count limit" in str(e)
return
else: raise e
# convert model to torch

View File

@@ -578,7 +578,7 @@ class TestAutoCastType(unittest.TestCase):
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
@given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
def test_int_to_float_unary_func(self, dtype):
for func in [
lambda t: t.exp(),

View File

@@ -477,7 +477,7 @@ class TestSchedule(unittest.TestCase):
def test_double_from(self):
x = Tensor([1,2,3,4])
out = x.to('npy')
out = x.to('python')
check_schedule(out, 0, filter_sink=False)
def test_pow_const_tensor_simplified(self):

View File

@@ -388,11 +388,6 @@ class TestUOpStr(unittest.TestCase):
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
assert_equiv_uops(sink, eval(str(sink)))
def test_variable_const(self):
# TODO: this is not possible after VALID.
uop = UOp(UOps.CONST, dtypes.int, (), arg=Variable("a",1,10))
assert str(eval(str(uop))) == str(uop)
def test_vectorized_str(self):
vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
assert str(eval(str(vec))) == str(vec)

View File

@@ -94,6 +94,7 @@ class TestUOpsStats(unittest.TestCase):
# NOTE: ops also include indexing ops
assert expected_ops <= ops and ops <= expected_ops * 2
@unittest.skipIf(getenv("PTX"), "wrong in PTX")
def test_simple_add_sq(self):
a = Tensor.empty(100,100)
b = Tensor.empty(100,100)

View File

@@ -3,6 +3,7 @@ from PIL import Image
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
from tinygrad.shape.symbolic import Variable, NumNode
import numpy as np
VARIABLE = ContextVar("VARIABLE", 0)
@@ -188,6 +189,12 @@ class TestFullyFlatten(unittest.TestCase):
self.assertEqual(fully_flatten([[1, 2, [3, 4]], [5, 6], 7]), [1, 2, 3, 4, 5, 6, 7])
self.assertEqual(fully_flatten([[1, "ab"], [True, None], [3.14, [5, "b"]]]), [1, "ab", True, None, 3.14, 5, "b"])
# fully_flatten must descend into numpy arrays (1-D and 2-D, inside lists or tuples, and mixed with plain Python lists)
# and yield their elements in order, exactly as it does for nested lists.
def test_fully_flatten_numpy(self):
self.assertEqual(fully_flatten([np.array([1, 3]), np.array([1, 2])]), [1, 3, 1, 2])
self.assertEqual(fully_flatten((np.array([1, 3]), np.array([1, 2]))), [1, 3, 1, 2])
self.assertEqual(fully_flatten([np.array([[1], [3]]), np.array([[1], [2]])]), [1, 3, 1, 2])
self.assertEqual(fully_flatten([[1, "ab"], [True, None], np.array([[3.14], [6.28]])]), [1, "ab", True, None, 3.14, 6.28])
class TestMemoryview(unittest.TestCase):
def test_from_mv_to_mv(self):
base = memoryview(bytearray(b"\x11\x22\x33"*40))

View File

@@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
if TYPE_CHECKING: from tinygrad.renderer import Renderer
@@ -266,12 +266,30 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx)))
# ***** transcendental *****
# ***** optional patterns *****
transcendental_patterns = [
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.EXP2), xexp2),
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.LOG2), xlog2),
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
]
@functools.lru_cache(None)
def transcendental_folding(ops):
return PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=k), cast(Callable, v))
for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if k not in ops])
# Build a PatternMatcher of optional rewrites selected by which ops are available.
# `ops` is the collection of supported ops (full_graph_rewrite passes the keys of `opts.code_for_op`, or an empty tuple);
# `force_transcendental` keeps the transcendental rewrites even for ops that are present in `ops`.
def get_extra_patterns(ops, force_transcendental=False):
# a transcendental pattern is kept when its op is missing from `ops`, or unconditionally when forced
pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental]
# shift rewrites are only added when both SHL and SHR are available
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
# the powers of two 2**0 .. 2**63
shiftable_consts = set([2**i for i in range(64)])
pat += [
# integer MUL by a power-of-two constant -> SHL by log2(const); the lambda returns None (no rewrite) for other constants
(UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda root, mul, const:
UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
# IDIV by a power-of-two constant -> SHR by log2(const)
# NOTE(review): unlike the MUL pattern this one has no dtype restriction, and python_alu implements IDIV as
# truncation toward zero while SHR is operator.rshift — these differ for negative dividends; confirm operands are non-negative.
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", src=(UPat.var("div"), UPat.cvar("const"))), lambda root, div, const:
UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None)]
if UnaryOps.NEG in ops:
# x * -1 -> NEG(x)
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
# x + NEG(y) -> SUB(x, y)
if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))]
if TernaryOps.MULACC in ops:
# a * b + c -> MULACC(a, b, c)
pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))]
return PatternMatcher(pat)
# ***** threefry *****
@@ -348,9 +366,10 @@ def no_vectorized_wmma(wmma:UOp):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
# bool MUL is AND, ADD/MAX is OR. prevents other rules from rewriting bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
(UPat.var('x', dtype=dtypes.bool).max(UPat.var('y')), lambda x,y: x|y),
# self ASSIGN is just self
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
@@ -715,11 +734,10 @@ linearize_cnt = 0
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
global linearize_cnt, acc_number
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
# do graph rewrite
acc_number = 0
sink = graph_rewrite(sink, folder)
sink = graph_rewrite(sink, constant_folder)
# rewrite pyint to int32
sink = graph_rewrite(sink, no_pyint)
@@ -727,11 +745,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# expand
linearize_cnt += 1
if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
sink = graph_rewrite(sink, folder+expander)
sink = graph_rewrite(sink, constant_folder+expander)
if getenv("DO_REDUCE", 1):
sink = graph_rewrite(sink, folder+just_reduce)
sink = graph_rewrite(sink, folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, folder+reducer)
sink = graph_rewrite(sink, constant_folder+just_reduce)
sink = graph_rewrite(sink, constant_folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, constant_folder+reducer)
sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
return sink

View File

@@ -27,7 +27,7 @@ class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
local: bool = False # images are never local
def scalar(self): return self.base
def scalar(self) -> DType: return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
@@ -44,13 +44,13 @@ class PtrDType(DType):
class dtypes:
@staticmethod
@functools.lru_cache(None)
def is_float(x: DType) -> bool: return x.scalar() in {dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64}
def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
@functools.lru_cache(None)
def is_int(x: DType) -> bool: return x.scalar() in {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.pyint} or dtypes.is_unsigned(x)
def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
@staticmethod
@functools.lru_cache(None)
def is_unsigned(x: DType) -> bool: return x.scalar() in {dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64}
def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
@staticmethod
def from_py(x) -> DType:
if x.__class__ is float: return dtypes.default_float
@@ -114,6 +114,11 @@ class dtypes:
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
floats = (float16, bfloat16, float32, float64)
uints = (uint8, uint16, uint32, uint64)
sints = (int8, int16, int32, int64, pyint)
ints = uints + sints
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
dtypes.default_float = getattr(dtypes, env_default_float.lower())
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
@@ -137,7 +142,8 @@ def least_upper_dtype(*ds:DType) -> DType:
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void')) or v.__class__ is staticmethod)}
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
or v.__class__ is staticmethod or isinstance(v, tuple))}
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
INVERSE_DTYPES_DICT['void'] = 'void'

View File

@@ -3,7 +3,6 @@ from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
from tinygrad.shape.symbolic import Variable, sint
@@ -42,6 +41,10 @@ class LBScheduleItem:
@property
def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
# Immutable context passed to full_ast_rewrite for one schedule item.
@dataclass(frozen=True)
class ScheduleItemContext:
# ordered buffers of the item (built from outs+inputs); enumerate_bufs rewrites a DEFINE_GLOBAL whose arg is a
# Buffer into one whose arg is that buffer's index in this tuple
bufs: Tuple[Buffer, ...]
# *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***
# ** helpers for doing movementops on uops
@@ -113,9 +116,14 @@ reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
def full_ast_rewrite(sink:UOp) -> UOp:
enumerate_bufs = PatternMatcher([
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda ctx,x: x.replace(arg=ctx.bufs.index(x.arg)) if isinstance(x.arg, Buffer) else None),
])
def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
if not AST_REWRITE: return sink
return graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(sink, reduceop_fusor)
return graph_rewrite(sink, enumerate_bufs, ctx)
# *** List[LazyBuffer] lowering to ScheduleItem ***
@@ -146,8 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
if buf not in assign_targets and buf not in inputs: inputs.append(buf)
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.index(buf))
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), buf.buffer)
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
# reduce ops change ShapeTracker
@@ -175,16 +182,16 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
ast: List[UOp] = []
inputs: List[LazyBuffer] = []
for i, out in enumerate(outs):
for out in outs:
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0]
output_st, vv = output_st.simplify().unbind()
var_vals.update(vv)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), out.buffer)
ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -194,7 +201,9 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
if buf in allbufs or buf.base.realized is not None: return
if GRAPH: log_lazybuffer(buf, scheduled)
if GRAPH:
from tinygrad.engine.graph import log_lazybuffer
log_lazybuffer(buf, scheduled)
# check if we need to realize views
if buf is not buf.base:
# fuse some pads
@@ -410,6 +419,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
lsi = queue.popleft()
if GRAPH:
kernel_number += 1
from tinygrad.engine.graph import realized_lazybuffer
for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
for out in lsi.outputs: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata))

View File

@@ -32,7 +32,14 @@ def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s:str): return len(ansistrip(s))
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
def fully_flatten(l):
  """Recursively flatten arbitrarily nested sequence-likes into one flat list.

  Any object exposing both ``__len__`` and ``__getitem__`` (lists, tuples, array-like
  objects) is descended into. Strings and all other objects are kept whole as leaves.

  Args:
    l: a nested sequence-like, or a single leaf value.

  Returns:
    A flat list of the leaves in left-to-right order. A lone leaf ``x`` yields ``[x]``.
  """
  # leaves: objects without the sequence protocol, and strings (which must not be split into characters)
  if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
  # an object reporting an empty `shape` (e.g. a 0-d array) holds a single value; `l[()]` extracts it
  if hasattr(l, "shape") and l.shape == (): return [l[()]]
  flattened = []
  # walk by index so that only the __len__/__getitem__ protocol checked above is relied upon
  for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
  return flattened
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
def round_up(num, amt:int): return (num+amt-1)//amt * amt
@@ -62,10 +69,12 @@ def get_child(obj, key):
return obj
def get_shape(x) -> Tuple[int, ...]:
if not isinstance(x, (list, tuple)): return ()
if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
if (aapi := (hasattr(x, "shape") and x.shape == ())): return ()
subs = [get_shape(xi) for xi in x]
if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
return (len(subs),) + (subs[0] if subs else ())
slen = 1 if aapi else len(subs)
return (slen,) + (subs[0] if subs else ())
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:

View File

@@ -16,7 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
"""
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
json_len = t[0:8].bitcast(dtypes.int64).item()
return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
"""

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
import sys, time, functools, itertools, math, operator, hashlib
from enum import auto, IntEnum, Enum
from collections import defaultdict
from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same
@@ -21,11 +20,11 @@ class FastEnum(IntEnum):
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
class UnaryOps(FastEnum):
"""A -> A (elementwise)"""
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
class BinaryOps(FastEnum):
"""A + A -> A (elementwise)"""
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
class TernaryOps(FastEnum):
"""A + A + A -> A (elementwise)"""
WHERE = auto(); MULACC = auto() # noqa: E702
@@ -170,8 +169,7 @@ class UOp(MathTrait):
def key(self) -> bytes:
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self):
return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
# *** uop syntactic sugar
@property
def st_arg(self) -> ShapeTracker:
@@ -304,7 +302,8 @@ python_alu: Dict[Op, Callable] = {
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub,
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul,
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
@@ -412,7 +411,7 @@ class UPat(MathTrait):
return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name)
@staticmethod
@functools.lru_cache(None)
def const(dtype:Optional[DType], b:ConstType|Variable): return UPat(UOps.CONST, dtype=dtype, arg=b)
def const(dtype:Optional[DType], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
# copied from UOp
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
@@ -442,7 +441,7 @@ class UPat(MathTrait):
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
(self.dtype is not None and uop.dtype not in self.dtype) or \
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
(self.arg is not None and self.arg != uop.arg) or \
(self.op is not None and uop.op not in self.op) or \
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
@@ -465,18 +464,19 @@ class UPatAny(UPat):
class PatternMatcher:
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list)
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = {}
# uop is required, arg is optional
for p,fxn in self.patterns:
assert p.op is not None
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn, p.early_reject))
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject))
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
if not early_reject.issubset(ler): continue
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
return None
@@ -501,7 +501,7 @@ class TrackedPatternMatcher(PatternMatcher):
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ret = None
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
st = time.perf_counter()
if not early_reject.issubset(ler):
match_stats[p][2] += time.perf_counter()-st

View File

@@ -1,5 +1,5 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct, math
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
from tinygrad.codegen.uopgraph import constant_folder
@@ -33,27 +33,15 @@ asm_for_op: Dict[Op, Callable] = {
}
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
shiftable_consts = set([2**i for i in range(64)])
ptx_matcher = constant_folder+PatternMatcher([
(UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
src=[UPat.cvar("const"), UPat.var("mul")]),
lambda root, mul, const: UOp(UOps.ALU, root.dtype,
(mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
src=[UPat.cvar("const"), UPat.var("div")]),
lambda root, div, const: UOp(UOps.ALU, root.dtype,
(div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
(UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"),
lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
(UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat.var("x", dtype=dtypes.bool),UPat.var("y")), name="root"),
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x, UOp.const(dtypes.bool, True)), BinaryOps.CMPNE), y), BinaryOps.MUL)),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("non_muls"), UPat(UOps.ALU, arg=BinaryOps.MUL, name="muls")], name="root"),
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
# upcast to float32 all the ops that don't support half
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)),
# fix the gates for load/store (low quality!)
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))),
lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
@@ -74,8 +62,7 @@ ptx_matcher = constant_folder+PatternMatcher([
lambda root, const: UOp(root.op, root.dtype,
(root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
UPat.var("alu"))), # no const here
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))),
lambda root, alu: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, 0))+root.src[2:])),

View File

@@ -8,13 +8,8 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
sidx = strip_parens(r[load.src[1]])
if isinstance(buf.dtype, ImageDType):
assert load.dtype == dtypes.float.vec(4), f"images must be float4, getting {load.dtype}"
val = f"read_imagef({r[buf]}, smp, {sidx})"
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and load.dtype.scalar() != dtypes.float16:
val = f"vload_half{'' if load.dtype.count == 1 else str(load.dtype.count)}(0, {r[buf]}+{sidx})"
elif load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
else:
val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
@@ -23,19 +18,15 @@ def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype)
return val
def render_store(r:CStyleLanguage, store:UOp, buf:UOp, var:UOp) -> str:
sidx = strip_parens(r[store.src[1]])
if isinstance(buf.dtype, ImageDType):
assert var.dtype == dtypes.float.vec(4), f"images must be float4, getting {var.dtype}"
val = f"write_imagef({r[buf]}, {sidx}, {r[var]});"
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and var.dtype.scalar() != dtypes.float16:
val = f"vstore_half{'' if var.dtype.count == 1 else str(var.dtype.count)}({r[var]}, 0, {r[buf]}+{sidx});"
elif var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp, gate:Optional[UOp]=None) -> str:
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
else:
val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};"
return f"if ({r[store.src[3]]}) {{ {val} }}" if len(store.src) > 3 and store.src[3].op is not UOps.IF else val
# TODO: this if should be in UOps, not here
return f"if ({r[gate]}) {{ {val} }}" if gate is not None else val
def render_alu(r:CStyleLanguage, x:UOp):
if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
@@ -54,6 +45,16 @@ base_pm = PatternMatcher([
(UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
(UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
# load/store image
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))),
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
# TODO: this if should be in UOps, not here
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4)),
UPat.var("gate", dtype=dtypes.bool))), lambda r,buf,idx,var,gate: f"if ({r[gate]}) {{ write_imagef({r[buf]}, {r[idx]}, {r[var]}); }}"),
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True),
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
# r method accesses
(UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
(UPat(UOps.VECTORIZE, name="x"),
@@ -77,7 +78,8 @@ base_pm = PatternMatcher([
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
# function calls
(UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat.var("var")), allow_any_len=True, name="store"), render_store),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate", dtype=dtypes.bool))), render_store),
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store),
(UPat(UOps.ALU, name="x"), render_alu),
(UPat(UOps.GEP, name="x"), render_gep),
])
@@ -104,7 +106,6 @@ class CStyleLanguage(Renderer):
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
extra_args: List[str] = []
float4: Optional[str] = None
uses_vload: bool = False
uses_ptr_arithmetic: bool = False
type_map: Dict[DType, str] = {}
infinity: str = "INFINITY"
@@ -112,8 +113,10 @@ class CStyleLanguage(Renderer):
code_for_op: Dict = {
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
UnaryOps.NEG: lambda x,dtype: f"-{x}",
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
@@ -240,7 +243,6 @@ class OpenCLRenderer(CStyleLanguage):
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
float4 = "(float4)"
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
uses_vload = True
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
def render_cast(self, x, var_dtype, bitcast=False) -> str:
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)

View File

@@ -4,7 +4,6 @@ import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
@@ -44,10 +43,14 @@ def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)
def _from_np_dtype(npdtype:np.dtype) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
import numpy as np
return dtypes.fields()[np.dtype(npdtype).name]
def _to_np_dtype(dtype:DType) -> Optional[type]:
import numpy as np
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
def _fromnp(x: np.ndarray) -> LazyBuffer:
def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
ret.buffer.allocate(x)
@@ -62,7 +65,7 @@ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
truncate_function = truncate[dtype]
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
# fake realize
ret.buffer.allocate(memoryview(data))
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
del ret.srcs
return ret
@@ -106,7 +109,7 @@ class Tensor:
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable, pathlib.Path],
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, Variable, pathlib.Path], # type: ignore [name-defined] # noqa: F821
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
if dtype is not None: dtype = to_dtype(dtype)
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
@@ -132,12 +135,14 @@ class Tensor:
if dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
else: data = _frompy(data, dtype)
elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif isinstance(data, np.ndarray):
elif str(type(data)) == "<class 'numpy.ndarray'>":
import numpy as np
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
elif isinstance(data, pathlib.Path):
dtype = dtype or dtypes.uint8
data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
@@ -295,7 +300,7 @@ class Tensor:
"""
return self.data().tolist()
def numpy(self) -> np.ndarray:
def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
"""
Returns the value of this tensor as a `numpy.ndarray`.
@@ -304,6 +309,7 @@ class Tensor:
print(repr(t.numpy()))
```
"""
import numpy as np
if self.dtype == dtypes.bfloat16: return self.float().numpy()
assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
@@ -3141,7 +3147,7 @@ class Tensor:
"""
return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
"""
Computes the sparse categorical cross-entropy loss between `self` and `Y`.

View File

@@ -10,12 +10,14 @@ from tinygrad.helpers import Context, getenv, to_function_name
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import uops_colors, word_wrap
from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import full_ast_rewrite
from tinygrad.engine.schedule import ScheduleItemContext, full_ast_rewrite
# **** /graph - detailed UOp + rewrites
# NOTE: UPats in ops.py are spec
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
# TODO: fix key for uop with buffer
def graph_rewrites(ctx:TrackedRewriteContext):
return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py" and not ("schedule" in ctx.loc[0] and "DEFINE_GLOBAL" in str(x[2]))]
@dataclass(frozen=True)
class RewriteLocation:
@@ -95,7 +97,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
code = ""
for ctx in contexts:
if ctx.loc[0].split("/")[-1] == "schedule.py":
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
si_ctx = ScheduleItemContext(bufs=tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.DEFINE_GLOBAL))
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink, si_ctx)).p).name, prg.src
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx