mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
mixtral touch up: two lines
This commit is contained in:
@@ -16,7 +16,8 @@ class MixtureFeedForward:
|
||||
top = sorted(enumerate(choice), key=lambda x: -x[1])
|
||||
norm = top[0][1] + top[1][1]
|
||||
e1, e2 = self.experts[top[0][0]], self.experts[top[1][0]]
|
||||
ret = e1(x.to(e1.w1.weight.device)).to(x.device) * (top[0][1]/norm) + e2(x.to(e2.w1.weight.device)).to(x.device) * (top[1][1]/norm)
|
||||
ret = e1(x.to(e1.w1.weight.device)).to(x.device) * (top[0][1]/norm) + \
|
||||
e2(x.to(e2.w1.weight.device)).to(x.device) * (top[1][1]/norm)
|
||||
return ret
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user