ops_python: add image support (#3356)

* ops_python: add image support

* uops tests in their own CI

* fix ci
This commit is contained in:
George Hotz
2024-02-09 12:02:06 +01:00
committed by GitHub
parent 5f93061f67
commit 7726eef464
2 changed files with 54 additions and 10 deletions

View File

@@ -3,7 +3,7 @@
# this is the (living) definition of uops
from typing import Tuple, List, Optional, Any, Dict
import pickle, base64, itertools, time, math
from tinygrad.dtype import DType, dtypes
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import all_same, getenv
from tinygrad.device import Compiled, Allocator, Compiler
from tinygrad.codegen.uops import UOp, UOps
@@ -32,15 +32,18 @@ def exec_alu(arg, dtype, p):
raise NotImplementedError(f"no support for {arg}")
def _load(m, i):
if i<0 or i>=len(m): raise IndexError(f"access out of bounds, size is {len(m)} and access is {i}")
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
return m[i]
def load(inp, j=0):
if len(inp) == 4:
return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
else:
assert len(inp) == 2, "image loads not supported yet"
return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
def _store(m, i, v):
if i<0 or i>=len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
m[i] = v
class PythonProgram:
def __init__(self, name:str, lib:bytes):
self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
@@ -58,12 +61,21 @@ class PythonProgram:
uop, dtype, idp, arg = self.uops[i]
inp = [ul[v] for v in idp]
dtp = [dl[v] for v in idp]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is UOps.STORE:
if dtp[2].sz > 1:
assert len(inp) <= 3, "gated stores not supported yet"
if isinstance(dtp[0], ImageDType):
# image store
assert dtp[2].sz == 4
for j,val in enumerate(inp[2]):
for m,o,v in zip(inp[0], inp[1], val): m[o+j] = v
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
elif dtp[2].sz > 1:
for j,val in enumerate(inp[2]):
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
else:
for m,o,v in zip(*inp): m[o] = v
for m,o,v in zip(*inp): _store(m, o, v)
i += 1
continue
elif uop is UOps.END:
@@ -115,7 +127,16 @@ class PythonProgram:
else:
ul[i] = inp[0]
elif uop is UOps.LOAD:
if dtype.sz > 1:
if isinstance(dtp[0], ImageDType):
assert dtype.sz == 4
ul[i] = []
for j in range(dtype.sz):
ret = []
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
ul[i].append(ret)
elif dtype.sz > 1:
ul[i] = [load(inp, j) for j in range(dtype.sz)]
else:
ul[i] = load(inp)
@@ -155,7 +176,6 @@ class PythonProgram:
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
assert i in ul, (uop, dtype, idp, arg)
#print(i, uop, dtype, arg, ul[i] if i in ul else None)
i += 1
return time.perf_counter() - st