mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 09:05:40 -05:00
ops_python: add image support (#3356)
* ops_python: add image support * uops tests in their own CI * fix ci
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# this is the (living) definition of uops
|
||||
from typing import Tuple, List, Optional, Any, Dict
|
||||
import pickle, base64, itertools, time, math
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv
|
||||
from tinygrad.device import Compiled, Allocator, Compiler
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
@@ -32,15 +32,18 @@ def exec_alu(arg, dtype, p):
|
||||
raise NotImplementedError(f"no support for {arg}")
|
||||
|
||||
def _load(m, i):
|
||||
if i<0 or i>=len(m): raise IndexError(f"access out of bounds, size is {len(m)} and access is {i}")
|
||||
if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
||||
return m[i]
|
||||
def load(inp, j=0):
|
||||
if len(inp) == 4:
|
||||
return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
|
||||
else:
|
||||
assert len(inp) == 2, "image loads not supported yet"
|
||||
return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
|
||||
|
||||
def _store(m, i, v):
|
||||
if i<0 or i>=len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
|
||||
m[i] = v
|
||||
|
||||
class PythonProgram:
|
||||
def __init__(self, name:str, lib:bytes):
|
||||
self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
|
||||
@@ -58,12 +61,21 @@ class PythonProgram:
|
||||
uop, dtype, idp, arg = self.uops[i]
|
||||
inp = [ul[v] for v in idp]
|
||||
dtp = [dl[v] for v in idp]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
|
||||
if uop is UOps.STORE:
|
||||
if dtp[2].sz > 1:
|
||||
assert len(inp) <= 3, "gated stores not supported yet"
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
# image store
|
||||
assert dtp[2].sz == 4
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,o,v in zip(inp[0], inp[1], val): m[o+j] = v
|
||||
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
|
||||
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
|
||||
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
|
||||
elif dtp[2].sz > 1:
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
|
||||
else:
|
||||
for m,o,v in zip(*inp): m[o] = v
|
||||
for m,o,v in zip(*inp): _store(m, o, v)
|
||||
i += 1
|
||||
continue
|
||||
elif uop is UOps.END:
|
||||
@@ -115,7 +127,16 @@ class PythonProgram:
|
||||
else:
|
||||
ul[i] = inp[0]
|
||||
elif uop is UOps.LOAD:
|
||||
if dtype.sz > 1:
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
assert dtype.sz == 4
|
||||
ul[i] = []
|
||||
for j in range(dtype.sz):
|
||||
ret = []
|
||||
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
|
||||
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
|
||||
else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
|
||||
ul[i].append(ret)
|
||||
elif dtype.sz > 1:
|
||||
ul[i] = [load(inp, j) for j in range(dtype.sz)]
|
||||
else:
|
||||
ul[i] = load(inp)
|
||||
@@ -155,7 +176,6 @@ class PythonProgram:
|
||||
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
|
||||
ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
|
||||
assert i in ul, (uop, dtype, idp, arg)
|
||||
#print(i, uop, dtype, arg, ul[i] if i in ul else None)
|
||||
i += 1
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
Reference in New Issue
Block a user