Disable logging in early compile2 and lower kernel counts (#2090)

* Revert "Revert "openpilot kernel fix from 209 to 207 (#2006)" (#2065)"

This reverts commit 924ecc4d6a.

* gate behind OPT >= 4

* disable_logging in schedule

* simple

* from master

* more images

* revert that

* 206 kernels
This commit is contained in:
George Hotz
2023-10-16 20:15:24 -07:00
committed by GitHub
parent 442a27db8a
commit 5a4a62ecae
5 changed files with 19 additions and 10 deletions

View File

@@ -154,7 +154,7 @@ jobs:
- if: ${{ matrix.task == 'openpilot' }}
name: Test openpilot model compile and size
run: |
DEBUG=2 ALLOWED_KERNEL_COUNT=209 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
DEBUG=2 ALLOWED_KERNEL_COUNT=206 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'openpilot' }}
name: Test openpilot model correctness (float32)

View File

@@ -15,9 +15,10 @@ from extra.utils import fetch
from extra.onnx import get_run_onnx
from tinygrad.graph import print_tree
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv, ImageDType
from tinygrad.realize import run_schedule
from tinygrad.ops import LoadOps, Device, ScheduleItem
from tinygrad.features.image import fix_schedule_for_images
Device.DEFAULT = "GPU"
def get_schedule(fn:str) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
@@ -63,9 +64,16 @@ def lb_to_numbers(schedule):
if __name__ == "__main__":
schedule, schedule_independent = get_schedule(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL)
run_schedule(schedule_independent)
run_schedule(schedule_independent, disable_logging=True)
schedule = fix_schedule_for_images(schedule)
print("**** running real kernels ****")
image_count = 0
for si in schedule:
if isinstance(si.out.dtype, ImageDType):
image_count += 1
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
GlobalCounters.reset()
run_schedule(schedule)

View File

@@ -64,7 +64,7 @@ class TestInferenceMinKernels(unittest.TestCase):
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
img = Tensor.randn(1, 3, 224, 224)
# TODO: this seems very high
with CLCache(116):
with CLCache(115):
model.forward(img).realize()
def test_resnet(self):
@@ -78,7 +78,7 @@ class TestInferenceMinKernels(unittest.TestCase):
model = ViT(embed_dim=192, num_heads=3)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
img = Tensor.randn(1, 3, 224, 224)
with CLCache(223): # NOTE: this is way too high
with CLCache(222): # NOTE: this is way too high
out = model.forward(img)
assert len(CacheCollector.cache) == 0, "ViT prerealized?"
out.realize()
@@ -88,7 +88,7 @@ class TestInferenceMinKernels(unittest.TestCase):
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
with CLCache(94):
with CLCache(85):
model(Tensor([[1,2,3,4]]), 0).realize()
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")

View File

@@ -22,6 +22,7 @@ LAZYCACHE = getenv("LAZYCACHE", 1)
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
PUSH_RESHAPES = OPT>=4
# **** ast fixing functions ****
@@ -248,7 +249,7 @@ class LazyBuffer:
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps):
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root

View File

@@ -9,14 +9,14 @@ from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.features.image import fix_schedule_for_images
def run_schedule(schedule:List[ScheduleItem]):
def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
# HACK: images can be not usable due to shape
if IMAGE >= 2: schedule = fix_schedule_for_images(schedule)
# NOTE: if you for loop the schedule it's slow because nothing frees
while len(schedule):
si = schedule.pop(0)
log_schedule_item(si)
if not disable_logging: log_schedule_item(si)
assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
if DEBUG >= 3: print_tree(si.ast)
if si.ast.op in LoadOps: