Symbolic Shape JIT main PR (#1353)

* Symbolic Shape JIT

update tests

2 variables symbolic ops, adding more tests

test passing

cleanup

* more test cases

* single flag

* review update

* jit attention one piece

* realize

* symbolic_jit test for cuda

* old artifact

* works with cuda gpu but failed ci

* CUDACPU
This commit is contained in:
chenyu
2023-08-18 14:39:55 -07:00
committed by GitHub
parent 84e6693915
commit ae39cf84ab
13 changed files with 223 additions and 65 deletions

View File

@@ -7,7 +7,8 @@ import json
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for fxn,args in run.jit_cache:
for fxn,args,var_vals in run.jit_cache:
assert not var_vals, "symbolic shape is not supported"
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
for i,arg in enumerate(args):