fewer lines, and oddly faster

This commit is contained in:
George Hotz
2022-06-18 21:48:42 -07:00
parent aa164d901e
commit 395eb60f46
2 changed files with 12 additions and 19 deletions

View File

@@ -56,39 +56,31 @@ def clbuild(name, prg, options=tuple(), argdtypes=None):
#print(prg)
return CLProgram(name, prg, options, argdtypes)
# Lookup table from an op enum member (UnaryOps / BinaryOps) to the text of the
# expression that unary_op / binary_op splice into their generated kernel source.
# The expressions refer to the kernel locals A (loaded from the first input) and
# B (loaded from the second input).
# NOTE(review): DIV is written "B/A", i.e. second operand over first. The older
# "b/a" branch in binary_op has the same order, so it looks deliberate -- confirm
# against the callers' argument order.
# NOTE(review): CMPEQ is "(A==B)" applied to float4 operands. Under the OpenCL C
# rules for vector relational operators a true lane is -1, so wrapping it in
# convert_float4 would give -1.0 where the older per-component expression gave
# 1.0 -- confirm that callers tolerate the sign.
code_for_op = {
UnaryOps.RELU: 'max(A, (float)0.)', UnaryOps.EXP: 'exp(A)', UnaryOps.LOG: 'log(A)', UnaryOps.NEG: '-A', UnaryOps.SIGN: 'sign(A)',
BinaryOps.ADD: "A+B", BinaryOps.SUB: "A-B", BinaryOps.MUL: "A*B", BinaryOps.DIV: "B/A", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)"
}
# Apply an elementwise unary op to x by building and running a small "unop"
# kernel; the result is a new GPUBuffer created with x's shape.
# NOTE(review): this listing comes from a diff view whose +/- markers and
# indentation were lost, so both sides of the change appear back to back. The
# if/elif chain and the lowercase `a` kernel lines look like the removed side;
# the uppercase `A` / code_for_op lines look like the added side -- confirm
# against the commit before reusing this text as code.
def unary_op(op, x):
ret = GPUBuffer(x.shape)
# Per-op selection of the expression text; any op not listed raises.
if op == UnaryOps.RELU: code = 'max(a, (float)0.)'
elif op == UnaryOps.EXP: code = 'exp(a)'
elif op == UnaryOps.LOG: code = 'log(a)'
elif op == UnaryOps.NEG: code = '-a'
elif op == UnaryOps.SIGN: code = 'sign(a)'
else: raise Exception(f"{op} isn't supported")
# The kernel source is assembled by string concatenation around the expression
# text; each work item loads one float4 from a_g and stores one float4 to res_g.
unop = clbuild("unop", """
__kernel void unop(__global const float4 *a_g, __global float4 *res_g) {
int gid = get_global_id(0);
float4 a = a_g[gid];
res_g[gid] = """+code+""";
float4 A = a_g[gid];
res_g[gid] = convert_float4("""+code_for_op[op]+""");
}""")
# Launch size is the rounded-up element count divided by 4, matching the float4
# indexing in the kernel.
# NOTE(review): roundup is defined elsewhere; presumably it pads the count to a
# multiple of 4 -- confirm.
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
return ret
# Apply an elementwise binary op to x and y by building and running a small
# "binop" kernel; the result is a new GPUBuffer created with x's shape.
# NOTE(review): this listing comes from a diff view whose +/- markers and
# indentation were lost, so both sides of the change appear back to back. The
# if/elif chain and the lowercase `a` / `b` kernel lines look like the removed
# side; the uppercase `A` / `B` / code_for_op lines look like the added side --
# confirm against the commit before reusing this text as code.
def binary_op(op, x, y):
ret = GPUBuffer(x.shape)
# Per-op selection of the expression text; any op not listed raises.
# DIV is "b/a" (second operand over first), and CMPEQ compares component by
# component, scaling each scalar comparison result by 1.0f.
if op == BinaryOps.ADD: code = "a+b"
elif op == BinaryOps.SUB: code = "a-b"
elif op == BinaryOps.MUL: code = "a*b"
elif op == BinaryOps.DIV: code = "b/a"
elif op == BinaryOps.POW: code = "pow(a,b)"
elif op == BinaryOps.CMPEQ: code = "(float4)(1.0f*(a.x==b.x), 1.0f*(a.y==b.y), 1.0f*(a.z==b.z), 1.0f*(a.w==b.w))"
else: raise Exception(f"{op} isn't supported")
# Both inputs must have exactly the result's shape; nothing here broadcasts.
assert x.shape == ret.shape and y.shape == ret.shape
# The kernel source is assembled by string concatenation around the expression
# text; each work item loads one float4 from a_g and one from b_g, then stores
# one float4 to res_g.
binop = clbuild("binop", """
__kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
int gid = get_global_id(0);
float4 a = a_g[gid];
float4 b = b_g[gid];
res_g[gid] = """+code+""";
float4 A = a_g[gid];
float4 B = b_g[gid];
res_g[gid] = convert_float4("""+code_for_op[op]+""");
}""")
# Launch size is the rounded-up element count divided by 4, matching the float4
# indexing in the kernel.
# NOTE(review): roundup is defined elsewhere; presumably it pads the count to a
# multiple of 4 -- confirm.
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
return ret

View File

@@ -42,7 +42,8 @@ class Device:
DEFAULT = i if os.environ.get(name, 0) == "1" else DEFAULT
try:
llops[i] = importlib.import_module('tinygrad.llops.'+op)
buffers[i] = [cls for name, cls in inspect.getmembers(llops[i], inspect.isclass) if name.endswith("Buffer")][0]
def find_buffer(llo, name): return [cls for cname, cls in inspect.getmembers(llo, inspect.isclass) if (cname.upper() == name + "BUFFER")][0]
buffers[i] = find_buffer(llops[i], name)
except ImportError as e:
print(op, "not available", e)
DEFAULT = CPU if DEFAULT is None else DEFAULT