mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
9
.github/workflows/test.yml
vendored
9
.github/workflows/test.yml
vendored
@@ -140,6 +140,15 @@ jobs:
|
||||
python -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
|
||||
pip install mypy
|
||||
mypy -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
|
||||
- name: Run beautiful_mnist without numpy
|
||||
run: |
|
||||
mkdir $HOME/test_no_numpy_dir
|
||||
cd $HOME/test_no_numpy_dir
|
||||
python -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install $GITHUB_WORKSPACE
|
||||
cp $GITHUB_WORKSPACE/examples/beautiful_mnist.py .
|
||||
PYTHONPATH=$GITHUB_WORKSPACE BS=2 STEPS=10 python beautiful_mnist.py
|
||||
- name: Test DEBUG
|
||||
run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
|
||||
- name: Repo line count <9800 lines
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.ops import UOps
|
||||
from tinygrad.device import Compiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
|
||||
from tinygrad.helpers import DEBUG, ansilen, getenv, colored
|
||||
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
|
||||
from tinygrad.shape.symbolic import sym_infer
|
||||
|
||||
def get_sched_resnet():
|
||||
@@ -53,7 +53,7 @@ def get_sched_bert():
|
||||
# ignore grad norm and loss scaler for now
|
||||
loss.backward()
|
||||
targets += [x.lazydata for x in optim.schedule_step()]
|
||||
sched = create_schedule(targets, seen)
|
||||
sched = create_schedule(targets)
|
||||
print(f"schedule length {len(sched)}")
|
||||
return sched
|
||||
|
||||
@@ -128,7 +128,7 @@ if __name__ == "__main__":
|
||||
running_gflops += gflops * tm
|
||||
if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)
|
||||
usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
|
||||
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[str(m) for m in si.metadata] if si.metadata is not None else ''}")
|
||||
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[repr(m) if TRACEMETA >= 2 else str(m) for m in si.metadata] if si.metadata is not None else ''}")
|
||||
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
|
||||
print("usage:")
|
||||
for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:
|
||||
|
||||
@@ -644,6 +644,13 @@ def train_bert():
|
||||
else:
|
||||
MLLOGGER = None
|
||||
|
||||
# ** init wandb **
|
||||
WANDB = getenv("WANDB")
|
||||
if WANDB:
|
||||
import wandb
|
||||
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
|
||||
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
|
||||
|
||||
# ** hyperparameters **
|
||||
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
|
||||
@@ -672,7 +679,7 @@ def train_bert():
|
||||
|
||||
Tensor.manual_seed(seed) # seed for weight initialization
|
||||
|
||||
model = get_mlperf_bert_model(init_ckpt if not INITMLPERF else None)
|
||||
model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
|
||||
|
||||
for _, x in get_state_dict(model).items():
|
||||
x.realize().to_(GPUS)
|
||||
@@ -727,14 +734,8 @@ def train_bert():
|
||||
start_step = int(scheduler_wd.epoch_counter.numpy().item())
|
||||
print(f"resuming from {ckpt} at step {start_step}")
|
||||
|
||||
# ** init wandb **
|
||||
WANDB = getenv("WANDB")
|
||||
if WANDB:
|
||||
import wandb
|
||||
wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
|
||||
wandb.init(config=config, **wandb_args, project="MLPerf-BERT")
|
||||
|
||||
if not INITMLPERF:
|
||||
if RUNMLPERF:
|
||||
# only load real data with RUNMLPERF
|
||||
eval_it = iter(batch_load_val_bert(EVAL_BS))
|
||||
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
|
||||
for _ in range(start_step): next(train_it) # Fast forward
|
||||
@@ -743,10 +744,12 @@ def train_bert():
|
||||
step_times = []
|
||||
# ** train loop **
|
||||
wc_start = time.perf_counter()
|
||||
if INITMLPERF:
|
||||
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
|
||||
else:
|
||||
if RUNMLPERF:
|
||||
# only load real data with RUNMLPERF
|
||||
i, train_data = start_step, get_data_bert(GPUS, train_it)
|
||||
else:
|
||||
i, train_data = start_step, get_fake_data_bert(GPUS, BS)
|
||||
|
||||
while train_data is not None and i < train_steps and not achieved:
|
||||
Tensor.training = True
|
||||
BEAM.value = TRAIN_BEAM
|
||||
@@ -759,10 +762,10 @@ def train_bert():
|
||||
pt = time.perf_counter()
|
||||
|
||||
try:
|
||||
if INITMLPERF:
|
||||
next_data = get_fake_data_bert(GPUS, BS)
|
||||
else:
|
||||
if RUNMLPERF:
|
||||
next_data = get_data_bert(GPUS, train_it)
|
||||
else:
|
||||
next_data = get_fake_data_bert(GPUS, BS)
|
||||
except StopIteration:
|
||||
next_data = None
|
||||
|
||||
@@ -807,10 +810,10 @@ def train_bert():
|
||||
BEAM.value = EVAL_BEAM
|
||||
|
||||
for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
|
||||
if INITMLPERF:
|
||||
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
|
||||
else:
|
||||
if RUNMLPERF:
|
||||
eval_data = get_data_bert(GPUS, eval_it)
|
||||
else:
|
||||
eval_data = get_fake_data_bert(GPUS, EVAL_BS)
|
||||
GlobalCounters.reset()
|
||||
st = time.time()
|
||||
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
|
||||
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
|
||||
|
||||
export BENCHMARK=10 DEBUG=2
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
|
||||
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
|
||||
|
||||
export WANDB=1
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
@@ -3,14 +3,11 @@
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export SUBMISSION_PLATFORM="tinybox_green"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
|
||||
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
|
||||
|
||||
# pip install -e ".[mlperf]"
|
||||
export LOGMLPERF=1
|
||||
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
|
||||
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
|
||||
|
||||
export BENCHMARK=10 DEBUG=2
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
|
||||
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
|
||||
|
||||
export WANDB=1
|
||||
|
||||
python3 examples/mlperf/model_train.py
|
||||
@@ -3,14 +3,11 @@
|
||||
export PYTHONPATH="."
|
||||
export MODEL="bert"
|
||||
export SUBMISSION_PLATFORM="tinybox_red"
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=84 EVAL_BS=6
|
||||
export DEFAULT_FLOAT="HALF" GPUS=6 BS=54 EVAL_BS=6
|
||||
|
||||
export BEAM=4
|
||||
export BASEDIR="/raid/datasets/wiki"
|
||||
|
||||
echo "TODO: DISABLING DROPOUT - UNSET FOR REAL SUBMISSION RUN"
|
||||
export DISABLE_DROPOUT=1 # TODO: Unset flag for real submission run.
|
||||
|
||||
# pip install -e ".[mlperf]"
|
||||
export LOGMLPERF=1
|
||||
|
||||
|
||||
@@ -290,5 +290,5 @@ if __name__ == "__main__":
|
||||
if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
|
||||
ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
|
||||
distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
|
||||
assert distance < 45e-5, colored(f"validation failed with {distance=}", "red")
|
||||
assert distance < 50e-5, colored(f"validation failed with {distance=}", "red")
|
||||
print(colored(f"output validated with {distance=}", "green"))
|
||||
|
||||
@@ -49,15 +49,15 @@ class BertForPretraining:
|
||||
output = self.bert(input_ids, attention_mask, token_type_ids)
|
||||
return self.cls(output, masked_lm_positions)
|
||||
|
||||
def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
|
||||
# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
|
||||
def sparse_categorical_crossentropy(predictions:Tensor, labels:Tensor, ignore_index=-1):
|
||||
log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index)
|
||||
y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
|
||||
y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
|
||||
return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero
|
||||
# Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1):
  """
  Masked cross entropy between `predictions` (logits over the last axis) and integer `labels`.

  Positions where `labels == ignore_index` contribute neither to the summed loss nor to the count it is divided by.
  `ignore_index` is only ever compared against `labels`, so anything `labels != ...` accepts may be passed.
  """
  num_classes = predictions.shape[-1]
  keep = labels != ignore_index
  # one row of class ids [0, num_classes) per label, compared against that label to build a one-hot target
  class_ids = Tensor.arange(num_classes, requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), num_classes)
  one_hot = class_ids == labels.flatten().reshape(-1, 1)
  # zero the rows of ignored labels, then restore the label shape with a trailing class axis
  target = (one_hot * keep.reshape(-1, 1)).reshape(*labels.shape, num_classes)
  masked_nll = -((predictions.log_softmax() * target).sum())
  return masked_nll / (keep.sum() + 1e-5) # Small constant to avoid division by zero
|
||||
|
||||
masked_lm_loss = sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
|
||||
def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  """
  Pretraining loss: the masked-LM cross entropy plus the next-sentence binary cross entropy.

  `masked_lm_weights` is forwarded as the `ignore_index` of `self.sparse_categorical_crossentropy`,
  i.e. a masked-LM position is dropped wherever its id equals its weight.
  """
  mlm_term = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
  nsp_term = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
  return mlm_term + nsp_term
|
||||
|
||||
@@ -66,7 +66,7 @@ class BertForPretraining:
|
||||
valid = masked_lm_ids != 0
|
||||
masked_lm_predictions = prediction_logits.log_softmax().argmax(-1)
|
||||
masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
|
||||
masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)
|
||||
masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
|
||||
|
||||
seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1)
|
||||
seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
|
||||
|
||||
6
setup.py
6
setup.py
@@ -21,7 +21,7 @@ setup(name='tinygrad',
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License"
|
||||
],
|
||||
install_requires=["numpy"],
|
||||
install_requires=[],
|
||||
python_requires='>=3.8',
|
||||
extras_require={
|
||||
'llvm': ["llvmlite"],
|
||||
@@ -37,6 +37,7 @@ setup(name='tinygrad',
|
||||
],
|
||||
#'mlperf': ["mlperf-logging @ git+https://github.com/mlperf/logging.git@4.0.0-rc2"],
|
||||
'testing': [
|
||||
"numpy",
|
||||
"torch",
|
||||
"pillow",
|
||||
"pytest",
|
||||
@@ -63,7 +64,8 @@ setup(name='tinygrad',
|
||||
"mkdocstrings[python]",
|
||||
"markdown-callouts",
|
||||
"markdown-exec[ansi]",
|
||||
"black"
|
||||
"black",
|
||||
"numpy",
|
||||
],
|
||||
'testing_tf': [
|
||||
"tensorflow==2.15.1",
|
||||
|
||||
4
test/external/external_model_benchmark.py
vendored
4
test/external/external_model_benchmark.py
vendored
@@ -74,7 +74,9 @@ def benchmark_model(m, devices, validate_outs=False):
|
||||
del inputs, tinygrad_model, tinygrad_jitted_model
|
||||
except RuntimeError as e:
|
||||
# TODO: we don't run the dm model on METAL for now
|
||||
if Device.DEFAULT == "METAL": assert "buffer count limit" in str(e)
|
||||
if Device.DEFAULT == "METAL":
|
||||
assert "buffer count limit" in str(e)
|
||||
return
|
||||
else: raise e
|
||||
|
||||
# convert model to torch
|
||||
|
||||
@@ -578,7 +578,7 @@ class TestAutoCastType(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
|
||||
|
||||
@given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
|
||||
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
|
||||
def test_int_to_float_unary_func(self, dtype):
|
||||
for func in [
|
||||
lambda t: t.exp(),
|
||||
|
||||
@@ -477,7 +477,7 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
def test_double_from(self):
|
||||
x = Tensor([1,2,3,4])
|
||||
out = x.to('npy')
|
||||
out = x.to('python')
|
||||
check_schedule(out, 0, filter_sink=False)
|
||||
|
||||
def test_pow_const_tensor_simplified(self):
|
||||
|
||||
@@ -388,11 +388,6 @@ class TestUOpStr(unittest.TestCase):
|
||||
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
|
||||
assert_equiv_uops(sink, eval(str(sink)))
|
||||
|
||||
def test_variable_const(self):
|
||||
# TODO: this is not possible after VALID.
|
||||
uop = UOp(UOps.CONST, dtypes.int, (), arg=Variable("a",1,10))
|
||||
assert str(eval(str(uop))) == str(uop)
|
||||
|
||||
def test_vectorized_str(self):
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
|
||||
assert str(eval(str(vec))) == str(vec)
|
||||
|
||||
@@ -94,6 +94,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||
# NOTE: ops also include indexing ops
|
||||
assert expected_ops <= ops and ops <= expected_ops * 2
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "wrong in PTX")
|
||||
def test_simple_add_sq(self):
|
||||
a = Tensor.empty(100,100)
|
||||
b = Tensor.empty(100,100)
|
||||
|
||||
@@ -3,6 +3,7 @@ from PIL import Image
|
||||
from tinygrad.helpers import Context, ContextVar
|
||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
|
||||
from tinygrad.shape.symbolic import Variable, NumNode
|
||||
import numpy as np
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
||||
@@ -188,6 +189,12 @@ class TestFullyFlatten(unittest.TestCase):
|
||||
self.assertEqual(fully_flatten([[1, 2, [3, 4]], [5, 6], 7]), [1, 2, 3, 4, 5, 6, 7])
|
||||
self.assertEqual(fully_flatten([[1, "ab"], [True, None], [3.14, [5, "b"]]]), [1, "ab", True, None, 3.14, 5, "b"])
|
||||
|
||||
def test_fully_flatten_numpy(self):
  # numpy arrays, 1-D or 2-D, inside a list or tuple and mixed with plain lists, flatten to their scalar items
  cases = [
    ([np.array([1, 3]), np.array([1, 2])], [1, 3, 1, 2]),
    ((np.array([1, 3]), np.array([1, 2])), [1, 3, 1, 2]),
    ([np.array([[1], [3]]), np.array([[1], [2]])], [1, 3, 1, 2]),
    ([[1, "ab"], [True, None], np.array([[3.14], [6.28]])], [1, "ab", True, None, 3.14, 6.28]),
  ]
  for nested, expected in cases: self.assertEqual(fully_flatten(nested), expected)
|
||||
|
||||
class TestMemoryview(unittest.TestCase):
|
||||
def test_from_mv_to_mv(self):
|
||||
base = memoryview(bytearray(b"\x11\x22\x33"*40))
|
||||
|
||||
@@ -4,7 +4,7 @@ import functools, itertools, heapq, math, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
@@ -266,12 +266,30 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx)))
|
||||
|
||||
# ***** transcendental *****
|
||||
# ***** optional patterns *****
|
||||
|
||||
transcendental_patterns = [
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.EXP2), xexp2),
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.LOG2), xlog2),
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin),
|
||||
]
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def transcendental_folding(ops):
|
||||
return PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=k), cast(Callable, v))
|
||||
for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if k not in ops])
|
||||
def get_extra_patterns(ops, force_transcendental=False):
  """
  Assemble the optional rewrite rules that depend on which ops are available.

  `ops` is only used for membership tests; the caller in this file passes the keys of the renderer's `code_for_op`
  (or an empty tuple when there is no renderer). The returned PatternMatcher holds, in order:
    - each `transcendental_patterns` entry whose op is not in `ops`, or every entry when `force_transcendental` is set
    - MUL / IDIV by a power-of-two constant rewritten to SHL / SHR, when both shifts are in `ops`
    - `x * -1` rewritten to NEG when NEG is in `ops`, plus `x + NEG(y)` rewritten to SUB when SUB is in `ops` as well
    - `a * b + c` rewritten to MULACC when MULACC is in `ops`
  """
  rules = [(upat, cast(Callable, fxn)) for upat, fxn in transcendental_patterns if upat.arg not in ops or force_transcendental]

  if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
    # every power of two from 2**0 through 2**63
    shiftable_consts = {1 << i for i in range(64)}
    # only reached for constants already found in shiftable_consts
    def shift_amount(const): return UOp.const(dtypes.int, int(math.log2(const.arg)))
    rules.append((UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]),
      lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, shift_amount(const)), BinaryOps.SHL) if const.arg in shiftable_consts else None))
    # NOTE(review): unlike the MUL rule this one has no dtype filter -- confirm SHR agrees with IDIV for negative signed operands
    rules.append((UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", src=(UPat.var("div"), UPat.cvar("const"))),
      lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, shift_amount(const)), BinaryOps.SHR) if const.arg in shiftable_consts else None))

  if UnaryOps.NEG in ops:
    rules.append((UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG)))
    # the SUB rule matches an existing NEG node, so it is only added together with the NEG rule
    if BinaryOps.SUB in ops: rules.append((UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y)))

  if TernaryOps.MULACC in ops:
    rules.append((UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c)))

  return PatternMatcher(rules)
|
||||
|
||||
# ***** threefry *****
|
||||
|
||||
@@ -348,9 +366,10 @@ def no_vectorized_wmma(wmma:UOp):
|
||||
|
||||
# this is symbolic 2.0
|
||||
constant_folder = PatternMatcher([
|
||||
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
|
||||
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
|
||||
# bool MUL is AND, ADD/MAX is OR. this prevents other rules from rewriting bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool).max(UPat.var('y')), lambda x,y: x|y),
|
||||
# self ASSIGN is just self
|
||||
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
|
||||
# ASSIGN to global is just self
|
||||
@@ -715,11 +734,10 @@ linearize_cnt = 0
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
global linearize_cnt, acc_number
|
||||
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
|
||||
|
||||
# do graph rewrite
|
||||
acc_number = 0
|
||||
sink = graph_rewrite(sink, folder)
|
||||
sink = graph_rewrite(sink, constant_folder)
|
||||
|
||||
# rewrite pyint to int32
|
||||
sink = graph_rewrite(sink, no_pyint)
|
||||
@@ -727,11 +745,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
# expand
|
||||
linearize_cnt += 1
|
||||
if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
|
||||
sink = graph_rewrite(sink, folder+expander)
|
||||
sink = graph_rewrite(sink, constant_folder+expander)
|
||||
if getenv("DO_REDUCE", 1):
|
||||
sink = graph_rewrite(sink, folder+just_reduce)
|
||||
sink = graph_rewrite(sink, folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
sink = graph_rewrite(sink, folder+reducer)
|
||||
sink = graph_rewrite(sink, constant_folder+just_reduce)
|
||||
sink = graph_rewrite(sink, constant_folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
|
||||
sink = graph_rewrite(sink, constant_folder+reducer)
|
||||
sink = graph_rewrite(sink, constant_folder+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2))
|
||||
|
||||
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, opts.extra_matcher)
|
||||
return sink
|
||||
|
||||
@@ -27,7 +27,7 @@ class ImageDType(DType):
|
||||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
|
||||
base: DType
|
||||
local: bool = False # images are never local
|
||||
def scalar(self): return self.base
|
||||
def scalar(self) -> DType: return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
@@ -44,13 +44,13 @@ class PtrDType(DType):
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def is_float(x: DType) -> bool: return x.scalar() in {dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64}
|
||||
def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats
|
||||
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
|
||||
@functools.lru_cache(None)
|
||||
def is_int(x: DType) -> bool: return x.scalar() in {dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.pyint} or dtypes.is_unsigned(x)
|
||||
def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def is_unsigned(x: DType) -> bool: return x.scalar() in {dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64}
|
||||
def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
|
||||
@staticmethod
|
||||
def from_py(x) -> DType:
|
||||
if x.__class__ is float: return dtypes.default_float
|
||||
@@ -114,6 +114,11 @@ class dtypes:
|
||||
default_float: ClassVar[DType] = float32
|
||||
default_int: ClassVar[DType] = int32
|
||||
|
||||
floats = (float16, bfloat16, float32, float64)
|
||||
uints = (uint8, uint16, uint32, uint64)
|
||||
sints = (int8, int16, int32, int64, pyint)
|
||||
ints = uints + sints
|
||||
|
||||
if (env_default_float := getenv("DEFAULT_FLOAT", "")):
|
||||
dtypes.default_float = getattr(dtypes, env_default_float.lower())
|
||||
assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
|
||||
@@ -137,7 +142,8 @@ def least_upper_dtype(*ds:DType) -> DType:
|
||||
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void')) or v.__class__ is staticmethod)}
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
|
||||
or v.__class__ is staticmethod or isinstance(v, tuple))}
|
||||
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
|
||||
INVERSE_DTYPES_DICT['void'] = 'void'
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast, get_args
|
||||
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite
|
||||
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
|
||||
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
@@ -42,6 +41,10 @@ class LBScheduleItem:
|
||||
@property
|
||||
def inputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[len(self.ast.src):] if self.ast.op is UOps.SINK else self.bufs[1:]
|
||||
|
||||
# Immutable context that full_ast_rewrite hands to graph_rewrite as `ctx`.
@dataclass(frozen=True)
class ScheduleItemContext:
  # buffers of the kernel being lowered (built from outs+inputs by the caller);
  # enumerate_bufs replaces a DEFINE_GLOBAL's Buffer arg with that buffer's index in this tuple
  bufs: Tuple[Buffer, ...]
|
||||
|
||||
# *** UOp with SWIZZLE (movementops) rewriting to UOp we can index ***
|
||||
|
||||
# ** helpers for doing movementops on uops
|
||||
@@ -113,9 +116,14 @@ reduceop_fusor = PatternMatcher([
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
def full_ast_rewrite(sink:UOp) -> UOp:
|
||||
enumerate_bufs = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda ctx,x: x.replace(arg=ctx.bufs.index(x.arg)) if isinstance(x.arg, Buffer) else None),
|
||||
])
|
||||
|
||||
def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
  """
  Rewrite a kernel's SINK into its final scheduled form.

  sink: the SINK UOp whose DEFINE_GLOBALs carry Buffer objects as their arg.
  ctx: holds `bufs`, the ordered buffers of this kernel, used to turn each Buffer arg into its index.

  The reduceop_fusor pass is optional and gated by AST_REWRITE. The enumerate_bufs pass always runs:
  DEFINE_GLOBALs are created with a Buffer arg, so returning early with AST_REWRITE off would leave
  Buffer objects in the ast where indices are expected.
  """
  if AST_REWRITE: sink = graph_rewrite(sink, reduceop_fusor)
  return graph_rewrite(sink, enumerate_bufs, ctx)
|
||||
|
||||
# *** List[LazyBuffer] lowering to ScheduleItem ***
|
||||
|
||||
@@ -146,8 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
if buf not in assign_targets and buf not in inputs: inputs.append(buf)
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
||||
outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.index(buf))
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), buf.buffer)
|
||||
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
|
||||
|
||||
# reduce ops change ShapeTracker
|
||||
@@ -175,16 +182,16 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
|
||||
cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
|
||||
ast: List[UOp] = []
|
||||
inputs: List[LazyBuffer] = []
|
||||
for i, out in enumerate(outs):
|
||||
for out in outs:
|
||||
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
|
||||
if out.op is MetaOps.ASSIGN and out.arg:
|
||||
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
|
||||
output_st = out.arg[0]
|
||||
output_st, vv = output_st.simplify().unbind()
|
||||
var_vals.update(vv)
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
|
||||
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), out.buffer)
|
||||
ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
|
||||
sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
|
||||
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
|
||||
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
@@ -194,7 +201,9 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
|
||||
double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
|
||||
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
|
||||
if buf in allbufs or buf.base.realized is not None: return
|
||||
if GRAPH: log_lazybuffer(buf, scheduled)
|
||||
if GRAPH:
|
||||
from tinygrad.engine.graph import log_lazybuffer
|
||||
log_lazybuffer(buf, scheduled)
|
||||
# check if we need to realize views
|
||||
if buf is not buf.base:
|
||||
# fuse some pads
|
||||
@@ -410,6 +419,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
lsi = queue.popleft()
|
||||
if GRAPH:
|
||||
kernel_number += 1
|
||||
from tinygrad.engine.graph import realized_lazybuffer
|
||||
for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
|
||||
for out in lsi.outputs: del out.srcs # can only schedule once
|
||||
schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata))
|
||||
|
||||
@@ -32,7 +32,14 @@ def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
|
||||
def ansilen(s:str): return len(ansistrip(s))
|
||||
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
|
||||
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
|
||||
def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
|
||||
def fully_flatten(l):
|
||||
if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
|
||||
flattened = []
|
||||
if hasattr(l, "shape") and l.shape == (): flattened.append(l[()])
|
||||
else:
|
||||
for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
|
||||
return flattened
|
||||
return [l]
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
|
||||
def round_up(num, amt:int): return (num+amt-1)//amt * amt
|
||||
@@ -62,10 +69,12 @@ def get_child(obj, key):
|
||||
return obj
|
||||
|
||||
def get_shape(x) -> Tuple[int, ...]:
|
||||
if not isinstance(x, (list, tuple)): return ()
|
||||
if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
|
||||
if (aapi := (hasattr(x, "shape") and x.shape == ())): return ()
|
||||
subs = [get_shape(xi) for xi in x]
|
||||
if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
|
||||
return (len(subs),) + (subs[0] if subs else ())
|
||||
slen = 1 if aapi else len(subs)
|
||||
return (slen,) + (subs[0] if subs else ())
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
||||
|
||||
@@ -16,7 +16,7 @@ def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
||||
"""
|
||||
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
json_len = t[0:8].bitcast(dtypes.int64).item()
|
||||
return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())
|
||||
return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
|
||||
|
||||
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
|
||||
import sys, time, functools, itertools, math, operator, hashlib
|
||||
from enum import auto, IntEnum, Enum
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same
|
||||
@@ -21,11 +20,11 @@ class FastEnum(IntEnum):
|
||||
# NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
|
||||
class UnaryOps(FastEnum):
|
||||
"""A -> A (elementwise)"""
|
||||
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto() # noqa: E702
|
||||
EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
||||
class BinaryOps(FastEnum):
|
||||
"""A + A -> A (elementwise)"""
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
|
||||
class TernaryOps(FastEnum):
|
||||
"""A + A + A -> A (elementwise)"""
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
@@ -170,8 +169,7 @@ class UOp(MathTrait):
|
||||
def key(self) -> bytes:
|
||||
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self):
|
||||
return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
# *** uop syntactic sugar
|
||||
@property
|
||||
def st_arg(self) -> ShapeTracker:
|
||||
@@ -304,7 +302,8 @@ python_alu: Dict[Op, Callable] = {
|
||||
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
||||
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
|
||||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
|
||||
UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub,
|
||||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul,
|
||||
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
|
||||
@@ -412,7 +411,7 @@ class UPat(MathTrait):
|
||||
return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(dtype:Optional[DType], b:ConstType|Variable): return UPat(UOps.CONST, dtype=dtype, arg=b)
|
||||
def const(dtype:Optional[DType], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
|
||||
|
||||
# copied from UOp
|
||||
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
|
||||
@@ -442,7 +441,7 @@ class UPat(MathTrait):
|
||||
|
||||
def match(self:UPat, uop:UOp, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
|
||||
if (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
|
||||
(self.dtype is not None and uop.dtype not in self.dtype) or \
|
||||
(self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
|
||||
(self.arg is not None and self.arg != uop.arg) or \
|
||||
(self.op is not None and uop.op not in self.op) or \
|
||||
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
|
||||
@@ -465,18 +464,19 @@ class UPatAny(UPat):
|
||||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
self.patterns = patterns
|
||||
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list)
|
||||
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
|
||||
self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = {}
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
assert p.op is not None
|
||||
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn, p.early_reject))
|
||||
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject))
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
|
||||
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
|
||||
if not early_reject.issubset(ler): continue
|
||||
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
|
||||
return None
|
||||
@@ -501,7 +501,7 @@ class TrackedPatternMatcher(PatternMatcher):
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ret = None
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
|
||||
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
|
||||
st = time.perf_counter()
|
||||
if not early_reject.issubset(ler):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
|
||||
import struct, math
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
|
||||
from tinygrad.codegen.uopgraph import constant_folder
|
||||
@@ -33,27 +33,15 @@ asm_for_op: Dict[Op, Callable] = {
|
||||
}
|
||||
|
||||
supports_half: List[Op] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
shiftable_consts = set([2**i for i in range(64)])
|
||||
ptx_matcher = constant_folder+PatternMatcher([
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MUL, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
src=[UPat.cvar("const"), UPat.var("mul")]),
|
||||
lambda root, mul, const: UOp(UOps.ALU, root.dtype,
|
||||
(mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="root", dtype=tuple([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
|
||||
src=[UPat.cvar("const"), UPat.var("div")]),
|
||||
lambda root, div, const: UOp(UOps.ALU, root.dtype,
|
||||
(div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.CMPNE, src=(UPat(dtype=dtypes.bool),UPat()), name="root"),
|
||||
lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.CMPLT, src=(UPat.var("x", dtype=dtypes.bool),UPat.var("y")), name="root"),
|
||||
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x, UOp.const(dtypes.bool, True)), BinaryOps.CMPNE), y), BinaryOps.MUL)),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat.var("non_muls"), UPat(UOps.ALU, arg=BinaryOps.MUL, name="muls")], name="root"),
|
||||
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
|
||||
# bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
|
||||
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
|
||||
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
|
||||
# upcast to float32 all the ops that don't support half
|
||||
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
|
||||
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
|
||||
for op in asm_for_op.keys() if op not in supports_half],
|
||||
(UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
|
||||
lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)),
|
||||
# fix the gates for load/store (low quality!)
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat.var("x"),UPat.var("y"),UPat.var("z"),UPat.var("k"))),
|
||||
lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
|
||||
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
|
||||
@@ -74,8 +62,7 @@ ptx_matcher = constant_folder+PatternMatcher([
|
||||
lambda root, const: UOp(root.op, root.dtype,
|
||||
(root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)),
|
||||
UPat.var("alu"))), # no const here
|
||||
(UPat((UOps.LOAD, UOps.STORE), name="root", allow_any_len=True, src=(UPat((UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL)), UPat.var("alu"))),
|
||||
lambda root, alu: UOp(root.op, root.dtype,
|
||||
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
|
||||
UOp.const(dtypes.int64, 0))+root.src[2:])),
|
||||
|
||||
@@ -8,13 +8,8 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.renderer import Renderer, TensorCore
|
||||
|
||||
def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
|
||||
sidx = strip_parens(r[load.src[1]])
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
assert load.dtype == dtypes.float.vec(4), f"images must be float4, getting {load.dtype}"
|
||||
val = f"read_imagef({r[buf]}, smp, {sidx})"
|
||||
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and load.dtype.scalar() != dtypes.float16:
|
||||
val = f"vload_half{'' if load.dtype.count == 1 else str(load.dtype.count)}(0, {r[buf]}+{sidx})"
|
||||
elif load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
sidx = strip_parens(r[load.src[1]]) if load.src[1].arg == BinaryOps.ADD else r[load.src[1]]
|
||||
if load.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
val = f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(load.dtype)}*)({r[buf]}+{sidx}))"
|
||||
else:
|
||||
val = f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
|
||||
@@ -23,19 +18,15 @@ def render_load(r:CStyleLanguage, load:UOp, buf:UOp) -> str:
|
||||
if len(load.src) > 3 and load.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[load.src[3]], val, r[load.src[2]], load.dtype)
|
||||
return val
|
||||
|
||||
def render_store(r:CStyleLanguage, store:UOp, buf:UOp, var:UOp) -> str:
|
||||
sidx = strip_parens(r[store.src[1]])
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
assert var.dtype == dtypes.float.vec(4), f"images must be float4, getting {var.dtype}"
|
||||
val = f"write_imagef({r[buf]}, {sidx}, {r[var]});"
|
||||
elif r.uses_vload and buf.dtype.scalar() == dtypes.float16 and var.dtype.scalar() != dtypes.float16:
|
||||
val = f"vstore_half{'' if var.dtype.count == 1 else str(var.dtype.count)}({r[var]}, 0, {r[buf]}+{sidx});"
|
||||
elif var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
def render_store(r:CStyleLanguage, buf:UOp, idx:UOp, var:UOp, gate:Optional[UOp]=None) -> str:
|
||||
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
|
||||
if var.dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
prefix = r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix
|
||||
val = f"*(({prefix}{r.render_dtype(var.dtype)}*)({r[buf]}+{sidx})) = {r[var]};"
|
||||
else:
|
||||
val = f"*({r[buf]}+{sidx}) = {r[var]};" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}] = {r[var]};"
|
||||
return f"if ({r[store.src[3]]}) {{ {val} }}" if len(store.src) > 3 and store.src[3].op is not UOps.IF else val
|
||||
# TODO: this if should be in UOps, not here
|
||||
return f"if ({r[gate]}) {{ {val} }}" if gate is not None else val
|
||||
|
||||
def render_alu(r:CStyleLanguage, x:UOp):
|
||||
if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
|
||||
@@ -54,6 +45,16 @@ base_pm = PatternMatcher([
|
||||
(UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
|
||||
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
|
||||
(UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
|
||||
# load/store image
|
||||
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var"), UPat.var("gate"))),
|
||||
lambda r,buf,idx,var,gate: f"({r[gate]}?read_imagef({r[buf]}, smp, {r[idx]}):{r[var]})"),
|
||||
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)))),
|
||||
lambda r,buf,idx: f"read_imagef({r[buf]}, smp, {r[idx]})"),
|
||||
# TODO: this if should be in UOps, not here
|
||||
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4)),
|
||||
UPat.var("gate", dtype=dtypes.bool))), lambda r,buf,idx,var,gate: f"if ({r[gate]}) {{ write_imagef({r[buf]}, {r[idx]}, {r[var]}); }}"),
|
||||
(UPat(UOps.STORE, src=(UPat.var('buf'), UPat.var('idx', dtype=dtypes.int.vec(2)), UPat.var("var", dtype=dtypes.float.vec(4))), allow_any_len=True),
|
||||
lambda r,buf,idx,var: f"write_imagef({r[buf]}, {r[idx]}, {r[var]});"),
|
||||
# r method accesses
|
||||
(UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
|
||||
(UPat(UOps.VECTORIZE, name="x"),
|
||||
@@ -77,7 +78,8 @@ base_pm = PatternMatcher([
|
||||
(UPat(UOps.CONST, name="x"), lambda r,x: str(x.arg)),
|
||||
# function calls
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"),), allow_any_len=True, name="load"), render_load),
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat.var("var")), allow_any_len=True, name="store"), render_store),
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var"), UPat.var("gate", dtype=dtypes.bool))), render_store),
|
||||
(UPat(UOps.STORE, src=(UPat.var("buf"), UPat.var('idx'), UPat.var("var")), allow_any_len=True), render_store),
|
||||
(UPat(UOps.ALU, name="x"), render_alu),
|
||||
(UPat(UOps.GEP, name="x"), render_gep),
|
||||
])
|
||||
@@ -104,7 +106,6 @@ class CStyleLanguage(Renderer):
|
||||
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
|
||||
extra_args: List[str] = []
|
||||
float4: Optional[str] = None
|
||||
uses_vload: bool = False
|
||||
uses_ptr_arithmetic: bool = False
|
||||
type_map: Dict[DType, str] = {}
|
||||
infinity: str = "INFINITY"
|
||||
@@ -112,8 +113,10 @@ class CStyleLanguage(Renderer):
|
||||
code_for_op: Dict = {
|
||||
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
|
||||
UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
|
||||
UnaryOps.NEG: lambda x,dtype: f"-{x}",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
|
||||
BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})",
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
|
||||
BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
|
||||
BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
|
||||
@@ -240,7 +243,6 @@ class OpenCLRenderer(CStyleLanguage):
|
||||
barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
|
||||
float4 = "(float4)"
|
||||
code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
|
||||
uses_vload = True
|
||||
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
|
||||
def render_cast(self, x, var_dtype, bitcast=False) -> str:
|
||||
return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
|
||||
|
||||
@@ -4,7 +4,6 @@ import time, math, itertools, functools, struct, sys, inspect, pathlib, string,
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Literal
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
|
||||
@@ -44,10 +43,14 @@ def _metaop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
|
||||
if isinstance(device, str): return LazyBuffer.metaop(op, shape, dtype, device, arg, src)
|
||||
return MultiLazyBuffer([LazyBuffer.metaop(op, shape, dtype, d, arg, src) for d in device], None)
|
||||
|
||||
def _from_np_dtype(npdtype:np.dtype) -> DType: return dtypes.fields()[np.dtype(npdtype).name]
|
||||
def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
|
||||
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
|
||||
import numpy as np
|
||||
return dtypes.fields()[np.dtype(npdtype).name]
|
||||
def _to_np_dtype(dtype:DType) -> Optional[type]:
|
||||
import numpy as np
|
||||
return np.dtype(dtype.fmt).type if dtype.fmt is not None else None
|
||||
|
||||
def _fromnp(x: np.ndarray) -> LazyBuffer:
|
||||
def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821
|
||||
ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
|
||||
# fake realize
|
||||
ret.buffer.allocate(x)
|
||||
@@ -62,7 +65,7 @@ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
|
||||
truncate_function = truncate[dtype]
|
||||
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
|
||||
# fake realize
|
||||
ret.buffer.allocate(memoryview(data))
|
||||
ret.buffer.allocate(memoryview(data if Device.DEFAULT != "PYTHON" else bytearray(data)))
|
||||
del ret.srcs
|
||||
return ret
|
||||
|
||||
@@ -106,7 +109,7 @@ class Tensor:
|
||||
training: ClassVar[bool] = False
|
||||
no_grad: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer, Variable, pathlib.Path],
|
||||
def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, Variable, pathlib.Path], # type: ignore [name-defined] # noqa: F821
|
||||
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
|
||||
if dtype is not None: dtype = to_dtype(dtype)
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
@@ -132,12 +135,14 @@ class Tensor:
|
||||
if dtype is None:
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
|
||||
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
|
||||
if dtype == dtypes.bfloat16: data = Tensor(_fromnp(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
|
||||
else: data = _fromnp(np.array(data).astype(_to_np_dtype(dtype)))
|
||||
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
|
||||
else: data = _frompy(data, dtype)
|
||||
elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device)
|
||||
elif isinstance(data, np.ndarray):
|
||||
elif str(type(data)) == "<class 'numpy.ndarray'>":
|
||||
import numpy as np
|
||||
assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}"
|
||||
if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item())
|
||||
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data)
|
||||
else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined]
|
||||
elif isinstance(data, pathlib.Path):
|
||||
dtype = dtype or dtypes.uint8
|
||||
data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
|
||||
@@ -295,7 +300,7 @@ class Tensor:
|
||||
"""
|
||||
return self.data().tolist()
|
||||
|
||||
def numpy(self) -> np.ndarray:
|
||||
def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
|
||||
"""
|
||||
Returns the value of this tensor as a `numpy.ndarray`.
|
||||
|
||||
@@ -304,6 +309,7 @@ class Tensor:
|
||||
print(repr(t.numpy()))
|
||||
```
|
||||
"""
|
||||
import numpy as np
|
||||
if self.dtype == dtypes.bfloat16: return self.float().numpy()
|
||||
assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
|
||||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
@@ -3141,7 +3147,7 @@ class Tensor:
|
||||
"""
|
||||
return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
|
||||
|
||||
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
|
||||
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index:int=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
|
||||
"""
|
||||
Computes the sparse categorical cross-entropy loss between `self` and `Y`.
|
||||
|
||||
|
||||
@@ -10,12 +10,14 @@ from tinygrad.helpers import Context, getenv, to_function_name
|
||||
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
|
||||
from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.schedule import full_ast_rewrite
|
||||
from tinygrad.engine.schedule import ScheduleItemContext, full_ast_rewrite
|
||||
|
||||
# **** /graph - detailed UOp + rewrites
|
||||
|
||||
# NOTE: UPats in ops.py are spec
|
||||
def graph_rewrites(ctx:TrackedRewriteContext): return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]
|
||||
# TODO: fix key for uop with buffer
|
||||
def graph_rewrites(ctx:TrackedRewriteContext):
|
||||
return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py" and not ("schedule" in ctx.loc[0] and "DEFINE_GLOBAL" in str(x[2]))]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RewriteLocation:
|
||||
@@ -95,7 +97,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
|
||||
code = ""
|
||||
for ctx in contexts:
|
||||
if ctx.loc[0].split("/")[-1] == "schedule.py":
|
||||
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
|
||||
si_ctx = ScheduleItemContext(bufs=tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.DEFINE_GLOBAL))
|
||||
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink, si_ctx)).p).name, prg.src
|
||||
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
|
||||
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
|
||||
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
|
||||
|
||||
Reference in New Issue
Block a user