mixtral touch up: two lines

This commit is contained in:
George Hotz
2023-12-10 17:21:49 -08:00
parent b3982187d1
commit b01e3907a1

View File

@@ -16,7 +16,8 @@ class MixtureFeedForward:
top = sorted(enumerate(choice), key=lambda x: -x[1])
norm = top[0][1] + top[1][1]
e1, e2 = self.experts[top[0][0]], self.experts[top[1][0]]
ret = e1(x.to(e1.w1.weight.device)).to(x.device) * (top[0][1]/norm) + e2(x.to(e2.w1.weight.device)).to(x.device) * (top[1][1]/norm)
ret = e1(x.to(e1.w1.weight.device)).to(x.device) * (top[0][1]/norm) + \
e2(x.to(e2.w1.weight.device)).to(x.device) * (top[1][1]/norm)
return ret
if __name__ == "__main__":