@@ -1,7 +1,8 @@
import collections , time
from typing import List , Any , Dict , cast , Optional , Tuple , Set
from tinygrad . helpers import round_up , to_mv , PROFILE
from tinygrad . device import HCQCompiled , HCQAllocator , HCQSignal , Buffer , BufferOptions , Compiled , Device
from tinygrad . device import HCQCompiled , HCQAllocator , HCQSignal , HCQ Buffer, HWCommandQueue , HW ComputeQueue , HWCopyQueue , \
Buffer , BufferOptions , Compiled , Device
from tinygrad . shape . symbolic import Variable
from tinygrad . engine . realize import ExecItem , BufferXfer , CompiledRunner
from tinygrad . engine . jit import MultiGraphRunner
@@ -9,14 +10,14 @@ from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph ( MultiGraphRunner ) :
def __init__ ( self , jit_cache : List [ ExecItem ] , input_rawbuffers : List [ Buffer ] , var_vals : Dict [ Variable , int ] ) :
super ( ) . __init__ ( jit_cache , input_rawbuffers , var_vals )
self . devices = list ( set ( cast ( Any , d ) for ji in jit_cache for d in [ Device [ cast ( Buffer , x ) . device ] for x in ji . bufs ] ) )
self . devices = list ( set ( cast ( HCQCompiled , d ) for ji in jit_cache for d in [ Device [ cast ( Buffer , x ) . device ] for x in ji . bufs ] ) )
# Allocate kernel args.
kernargs_size : Dict [ Compiled , int ] = collections . defaultdict ( int )
for ji in self . jit_cache :
if not isinstance ( ji . prg , CompiledRunner ) : continue
kernargs_size [ ji . prg . device ] + = round_up ( ji . prg . clprg . kernargs_alloc_size , 16 )
self . kernargs_bufs : Dict [ Compiled , Any ] = { dev : dev . allocator . _alloc ( sz , BufferOptions ( cpu_access = True ) ) for dev , sz in kernargs_size . items ( ) }
self . kernargs_bufs : Dict [ Compiled , HCQBuffer ] = { dev : dev . allocator . _alloc ( sz , BufferOptions ( cpu_access = True ) ) for dev , sz in kernargs_size . items ( ) }
kernargs_ptrs : Dict [ Compiled , int ] = { dev : buf . va_addr for dev , buf in self . kernargs_bufs . items ( ) }
# Fill initial arguments.
@@ -37,19 +38,19 @@ class HCQGraph(MultiGraphRunner):
# graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
# global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’ s
# compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
self . comp_queues : Dict [ Compiled , Any ] = { dev : dev . hw_compute_queue_t ( ) for dev in self . devices }
self . copy_queues : Dict [ Compiled , Any ] = { dev : dev . hw_copy_queue_t ( ) for dev in self . devices }
self . comp_queues : Dict [ Compiled , HWComputeQueue ] = { dev : dev . hw_compute_queue_t ( ) for dev in self . devices }
self . copy_queues : Dict [ Compiled , HWCopyQueue ] = { dev : dev . hw_copy_queue_t ( ) for dev in self . devices }
self . signal_sched : Dict [ int , Tuple [ List , Any , Optional [ int ] , Optional [ List ] ] ] = { } # Dict[ji_idx, (deps, signal, sigval, prof_info)]
self . signals = { q : self . devices [ 0 ] . signal_t ( value = 0 ) for q in list ( self . comp_queues . values ( ) ) + list ( self . copy_ queues. value s( ) ) }
self . dev_kickoff_signal = { dev : self . devices [ 0 ] . signal_t ( value = 0 ) for dev in self . devices + [ ' CPU ' ] } # Dict[dev, signal]
self . signal_sched : Dict [ int , Tuple [ List , HCQSignal , Optional [ int ] , Optional [ List ] ] ] = { } # Dict[ji_idx, (deps, signal, sigval, prof_info)]
self . signals = { q : dev . signal_t ( value = 0 ) for queues in ( self . comp_queues , self . copy_queues ) for dev , q in queues. item s( ) } #type:ignore
self . dev_kickoff_signal = { * * { dev . dname : dev . signal_t ( value = 0 ) for dev in self . devices } , * * { " CPU " : self . devices [ 0 ] . signal_t ( value = 0 ) } }
self . kickoff_value = 0
self . save_devs : Dict [ Any , Set ] = { q : set ( ) for q in list ( self . comp_queues . values ( ) ) + list ( self . copy_queues . values ( ) ) }
self . save_devs : Dict [ HWCommandQueue , Set ] = { q : set ( ) for q in list ( self . comp_queues . values ( ) ) + list ( self . copy_queues . values ( ) ) }
for dev in self . devices : self . save_devs [ self . comp_queues [ dev ] ] . add ( dev )
self . last_timeline : Dict [ HCQCompiled , Tuple [ HCQSignal , int ] ] = { dev : ( dev . timeline_signal , 0 ) for dev in self . devices }
self . last_ji : Dict [ Any , Any ] = { q : None for q in list ( self . comp_queues . values ( ) ) + list ( self . copy_queues . values ( ) ) }
self . last_ji : Dict [ HWCommandQueue , Optional [ int ] ] = { q : None for q in list ( self . comp_queues . values ( ) ) + list ( self . copy_queues . values ( ) ) }
for j , ji in enumerate ( self . jit_cache ) :
enqueue_dev = ji . prg . device if isinstance ( ji . prg , CompiledRunner ) else Device [ ji . bufs [ 1 ] . device ] #type:ignore
@@ -80,11 +81,11 @@ class HCQGraph(MultiGraphRunner):
# Build hardware queues.
self . op_cmd_idx : Dict [ int , Tuple [ Any , int ] ] = { }
self . copy_to_devs : Dict [ Compiled , Set [ Compiled ] ] = { dev : set ( ) for dev in self . devices }
self . kickoff_wait_cmds : Dict [ Any , List ] = { q : list ( ) for q in list ( self . comp_queues . values ( ) ) + list ( self . copy_queues . values ( ) ) }
self . kickoff_wait_cmds : Dict [ HWCommandQueue , List ] = { q : list ( ) for q in list ( self . comp_queues . values ( ) ) + list ( self . copy_queues . values ( ) ) }
for dev in self . devices :
self . comp_queues [ dev ] . memory_barrier ( ) . wait ( dev . timeline_signal , dev . timeline_value - 1 ) \
. wait ( self . dev_kickoff_signal [ ' CPU ' ] , self . kickoff_value ) . signal ( self . dev_kickoff_signal [ dev ] , self . kickoff_value )
. wait ( self . dev_kickoff_signal [ ' CPU ' ] , self . kickoff_value ) . signal ( self . dev_kickoff_signal [ dev . dname ] , self . kickoff_value )
for j , ji in enumerate ( self . jit_cache ) :
deps , signal , signal_val , prof_info = self . signal_sched [ j ]
@@ -97,11 +98,12 @@ class HCQGraph(MultiGraphRunner):
if prof_info : enqueue_queue . timestamp ( prof_info [ 0 ] )
# Encode main commands based on ji type.
if isinstance ( ji . prg , CompiledRunner ) : enqueue_queue . exec ( ji . prg . clprg , self . kargs_addrs [ j ] , * ji . prg . p . launch_dims ( var_vals ) )
if isinstance ( ji . prg , CompiledRunner ) :
cast ( HWComputeQueue , enqueue_queue ) . exec ( ji . prg . clprg , self . kargs_addrs [ j ] , * ji . prg . p . launch_dims ( var_vals ) )
elif isinstance ( ji . prg , BufferXfer ) :
dest , src = [ cast ( Buffer , x ) for x in ji . bufs [ 0 : 2 ] ]
cast ( HCQAllocator , Device [ src . device ] . allocator ) . map ( dest . _buf )
enqueue_queue . copy ( dest . _buf . va_addr , src . _buf . va_addr , dest . nbytes )
cast ( HWCopyQueue , enqueue_queue) . copy ( dest . _buf . va_addr , src . _buf . va_addr , dest . nbytes )
self . copy_to_devs [ Device [ dest . device ] ] . add ( Device [ src . device ] )
self . op_cmd_idx [ j ] = ( enqueue_queue , len ( enqueue_queue ) - 1 )
@@ -116,15 +118,15 @@ class HCQGraph(MultiGraphRunner):
self . comp_queues [ dev ] . wait ( self . signals [ self . copy_queues [ dep_dev ] ] , self . signal_sched [ last_j ] [ 2 ] )
self . comp_queues [ dev ] . signal ( dev . timeline_signal , dev . timeline_value )
if hasattr ( self . comp_queues [ dev ] , ' bind ' ) : self . comp_queues [ dev ] . bind ( dev )
if hasattr ( self . copy_queues [ dev ] , ' bind ' ) and self . last_ji [ self . copy_queues [ dev ] ] is not None : self . copy_queues [ dev ] . bind ( dev )
self . comp_queues [ dev ] . bind ( dev )
if self . last_ji [ self . copy_queues [ dev ] ] is not None : self . copy_queues [ dev ] . bind ( dev )
def __call__ ( self , input_rawbuffers : List [ Buffer ] , var_vals : Dict [ Variable , int ] , wait = False ) - > Optional [ float ] :
# Wait and restore signals
self . kickoff_value + = 1
for dev in self . devices : self . last_timeline [ dev ] [ 0 ] . wait ( self . last_timeline [ dev ] [ 1 ] )
for queue in self . comp_queues . values ( ) : self . signals [ queue ] . value = 0
for queue in self . copy_queues . values ( ) : self . signals [ queue ] . value = 0
for comp_ queue in self . comp_queues . values ( ) : self . signals [ comp_ queue] . value = 0
for copy_ queue in self . copy_queues . values ( ) : self . signals [ copy_ queue] . value = 0
self . dev_kickoff_signal [ ' CPU ' ] . value = self . kickoff_value
if PROFILE and self . kickoff_value > 1 :
@@ -171,7 +173,7 @@ class HCQGraph(MultiGraphRunner):
for buf in read + write :
if buf . device not in self . save_devs [ queue ] :
self . save_devs [ queue ] . add ( buf . device )
sync_signals + = [ ( self . dev_kickoff_signal [ Device [ buf . device ] ] , self . kickoff_value ) ]
sync_signals + = [ ( self . dev_kickoff_signal [ Device [ buf . device ] . dname ] , self . kickoff_value ) ]
return [ ( self . signals [ k ] , max ( v for x , v in deps if id ( x ) == idk ) ) for idk , k in { id ( x [ 0 ] ) : x [ 0 ] for x in deps } . items ( ) ] + sync_signals
@@ -182,4 +184,4 @@ class HCQGraph(MultiGraphRunner):
if PROFILE and self . kickoff_value > 1 :
for _ , _ , _ , ( st , en , dev , desc , is_cp ) in self . signal_sched . values ( ) : dev . sig_prof_records + = [ ( st , en , desc , is_cp ) ] #type: ignore
for dev , buf in self . kernargs_bufs . items ( ) : dev . allocator . _free ( buf , BufferOptions ( cpu_access = True ) )
for f dev, buf in self . kernargs_bufs . items ( ) : f dev. allocator . _free ( buf , BufferOptions ( cpu_access = True ) )