unrealized consts everywhere (#1963)

* unrealized consts everywhere

* don't import device from lazy

* Device isn't in Lazy

* same issue

* disable jit random
George Hotz (committed by GitHub)
2023-10-04 01:48:10 -07:00
parent f04c1a63ae
commit 6a79d4044a
7 changed files with 18 additions and 20 deletions

View File

@@ -268,7 +268,7 @@ jobs:
- name: Install dependencies
run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='cuda'&&',cuda'||matrix.backend=='ptx'&&',cuda'||matrix.backend=='triton'&&',triton'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
- name: Check Device.DEFAULT
-run: python -c "from tinygrad.lazy import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT"
+run: python -c "from tinygrad.ops import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT"
- name: Run pytest (not cuda)
if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton'
run: python -m pytest -n=auto test/ -k '${{matrix.backend=='llvm'&&'not (test_nn.py and test_conv_transpose2d)'||'test'}}' -m 'not exclude_${{matrix.backend}}'
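
The step above can be reproduced locally with the new import path; this is just the workflow's one-liner unrolled, nothing beyond the diff is assumed:

# sketch of the updated CI check: Device now lives in tinygrad.ops and resolves a default backend
# (on a plain CPU box Device.DEFAULT may be 'CPU', so the assert fires there just as it would in CI)
from tinygrad.ops import Device
assert Device.DEFAULT in ['LLVM', 'CLANG', 'CUDA', 'GPU'], Device.DEFAULT
print(f"default backend: {Device.DEFAULT}")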

View File

@@ -8,8 +8,8 @@ from tinygrad.helpers import prod, dtypes
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
-from tinygrad.lazy import LazyBuffer, create_lazybuffer, Device
-from tinygrad.ops import ASTRunner
+from tinygrad.lazy import LazyBuffer, create_lazybuffer
+from tinygrad.ops import ASTRunner, Device
from tinygrad.shape.shapetracker import ShapeTracker
import pytest
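
The same relocation applies to user-facing imports: Device now lives only in tinygrad.ops, and tinygrad.lazy no longer re-exports it. A minimal sketch of the new split (the commented-out line is the path this diff removes):

# Device and ASTRunner come from tinygrad.ops; the LazyBuffer helpers stay in tinygrad.lazy
from tinygrad.ops import ASTRunner, Device
from tinygrad.lazy import LazyBuffer, create_lazybuffer
# from tinygrad.lazy import Device   # removed in this commit; would now fail with ImportError
print(Device.DEFAULT)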

View File

@@ -134,6 +134,7 @@ class TestJit(unittest.TestCase):
assert output2 != expect2
assert len(f.jit_cache) == 1
+@unittest.skip("random isn't working in JIT")
def test_jit_random_regen(self):
def f(a, b):
rn = Tensor.randn(*a.shape)
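
The skip exists because the JIT replays cached kernels, and with the RAND cache hook removed in realize.py below, a jitted function calling Tensor.randn keeps returning the captured noise. A rough sketch of the behaviour the skipped test is about; the TinyJit wrapper, return expression, shapes, and comparison are assumed rather than taken from this diff:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def f(a, b):
  rn = Tensor.randn(*a.shape)        # fresh noise is intended on every call
  return ((a + b) * rn).realize()

a, b = Tensor.randn(8, 8), Tensor.randn(8, 8)
outs = [f(a, b).numpy() for _ in range(4)]
# ideally every call yields different noise; with JIT random disabled the jitted replays
# repeat the captured values, which is why test_jit_random_regen is skipped for now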

View File

@@ -1,7 +1,8 @@
#!/usr/bin/env python
import numpy as np
import unittest
-from tinygrad.lazy import LazyBuffer, Device
+from tinygrad.lazy import LazyBuffer
+from tinygrad.ops import Device
from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import CacheCollector

View File

@@ -5,7 +5,7 @@ from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
-from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
@@ -123,7 +123,7 @@ class LazyBuffer:
@property
def base(self): return self._base if self._base is not None else self
-def is_unrealized_const(self): return not self.realized and (self.base.op.op == LoadOps.CONST and isinstance(Device[self.device], Compiled))
+def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST
@property
def realized(self): return self.base._realized
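
Dropping the isinstance(Device[self.device], Compiled) guard is the "everywhere" in the title: a LoadOps.CONST node now counts as an unrealized const on interpreted backends (CPU, TORCH) too, which is what the new BufferOps.CONST branch in ops.py below serves. Restated as a standalone sketch with the attributes spelled out (`buf` stands for any LazyBuffer; how a CONST node gets constructed is assumed, not shown in this diff):

from tinygrad.ops import LoadOps

def is_unrealized_const(buf) -> bool:
  # buf.realized is the backing RawBuffer (None until realize); buf.base follows views back to the root op
  return not buf.realized and buf.base.op.op == LoadOps.CONST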

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import time, importlib, inspect, functools, pathlib
+import numpy as np
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
@@ -113,9 +114,12 @@ class Interpreted:
self.codegen = None
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
-if ast.op == BufferOps.MEM and BufferOps.MEM not in self.fxn_for_op:
-  assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
-  buf = self.to_underlying(inputs[ast.arg.idx-1])
+if ast.op in BufferOps and ast.op not in self.fxn_for_op:
+  if ast.op == BufferOps.MEM:
+    assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
+    buf = self.to_underlying(inputs[ast.arg.idx-1])
+  elif ast.op == BufferOps.CONST:
+    buf = self.to_underlying(self.buffer.fromCPU(np.array(ast.arg.val, dtype=ast.arg.dtype.np)))
for mop,arg in ast.arg.st.to_movement_ops(): buf = self.fxn_for_op[mop](buf, arg)
return self.from_underlying(buf)
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
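
The new branch materializes a CONST on demand: build a 0-d numpy scalar of the right dtype, then replay the ShapeTracker's movement ops to give it its logical shape. A standalone numpy sketch of that idea; the function name and the movement-op encoding are illustrative, not tinygrad's API:

import numpy as np

def materialize_const(val, np_dtype, movement_ops):
  buf = np.array(val, dtype=np_dtype)            # 0-d scalar, like self.buffer.fromCPU above
  for mop, arg in movement_ops:                  # stand-in for ast.arg.st.to_movement_ops()
    if mop == "RESHAPE": buf = buf.reshape(arg)
    elif mop == "EXPAND": buf = np.broadcast_to(buf, arg)
  return buf

print(materialize_const(2.0, np.float32, [("RESHAPE", (1, 1)), ("EXPAND", (3, 3))]))   # 3x3 of 2.0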

View File

@@ -39,7 +39,7 @@ def _realize_custom(buffer: LazyBuffer) -> None:
buffer.realized = buffer.op.arg(buffer, *[x.realize() for x in buffer.op.src])
def _realize_from(buffer: LazyBuffer) -> None:
-rawbuf = buffer.op.src[0].realize()
+rawbuf = cast(LazyBuffer, buffer.op.src[0]).contiguous().realize()
assert rawbuf.realized, "realize failed?"
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {rawbuf.device} size {rawbuf.realized.size} dtype {rawbuf.realized.dtype}")
# TODO: make this generic
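
The FROM path now forces the source contiguous before the cross-device copy; the apparent motivation is that a view-like or const-like source has no flat realized buffer to copy from. A numpy analogy of that step (numpy stands in for the lazy buffer here, this is not tinygrad code):

import numpy as np

src = np.broadcast_to(np.float32(2.0), (4, 4))   # 16 logical elements backed by a single value
flat = np.ascontiguousarray(src)                 # the .contiguous() step: a real 4x4 buffer appears
assert flat.nbytes == 16 * 4 and flat.flags["C_CONTIGUOUS"]
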
@@ -56,17 +56,10 @@ def _realize_empty(buffer: LazyBuffer) -> None:
assert all_int(buffer.shape), "does not support symbolic shape"
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
-def _gen_rand(rng, shape, dt): return rng.random(size=shape, dtype=np.float32).astype(dtype=dt, copy=False)
def _realize_rand(buffer: LazyBuffer) -> None:
assert all_int(buffer.shape), "does not support symbolic shape"
rng = np.random.default_rng(buffer.op.arg)
-buffer.realized = Device[buffer.device].buffer.fromCPU(_gen_rand(rng, buffer.shape, buffer.dtype.np), **buffer._device_extra_args()) # type: ignore
-# Jit support
-from tinygrad.jit import CacheCollector
-CacheCollector.add(lambda args, vars, jit: args[0]._copyin(_gen_rand(*args[1:])), [buffer.realized, rng, buffer.shape, buffer.dtype.np], {})
-def _realize_const(buffer: LazyBuffer) -> None:
-buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
+buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
LoadOps.CONTIGUOUS: _realize_contiguous,
@@ -74,5 +67,4 @@ LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
LoadOps.FROM: _realize_from,
LoadOps.EMPTY: _realize_empty,
LoadOps.RAND: _realize_rand,
-LoadOps.CONST: _realize_const,
}
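
With consts staying unrealized, LoadOps.CONST no longer needs a dispatcher entry at all, and RAND realization shrinks to a single seeded numpy draw with no CacheCollector hook. A self-contained sketch of that remaining behaviour; the helper name is made up, and the real code above writes into a device RawBuffer rather than returning an array:

import math
import numpy as np

def realize_rand_like(seed, shape, np_dtype):
  # mirror of the simplified _realize_rand: flat float32 noise from a seeded generator,
  # then a cast to the target dtype (copy=False avoids a copy when the dtype already matches)
  rng = np.random.default_rng(seed)
  return rng.random(size=math.prod(shape), dtype=np.float32).astype(np_dtype, copy=False)

print(realize_rand_like(42, (2, 3), np.float16).shape)   # (6,) -- a flat buffer for a 2x3 tensor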