Here is my PyTorch implementation for computing the (weighted) average of a set of quaternions:
def mean(Q, weights=None):
    """Return the (weighted) average of a batch of unit quaternions.

    Uses the eigenvector method of Markley et al., "Averaging Quaternions":
    the average is the eigenvector belonging to the largest eigenvalue of
    the weighted scatter matrix ``M = sum_i w_i q_i q_i^T / sum_i w_i``.
    Because ``q q^T == (-q)(-q)^T``, the result does not depend on the sign
    of the individual input quaternions.

    Args:
        Q: Tensor of shape ``(N, 4)``, one quaternion per row.
        weights: Optional weights of shape ``(N,)`` (tensor or sequence).
            They are converted to ``Q``'s dtype and device. Defaults to
            uniform weights ``1 / N``.

    Returns:
        Tensor of shape ``(4,)`` on ``Q``'s device and dtype. Its sign is
        chosen so that the first component is non-negative (for a
        scalar-first ``(w, x, y, z)`` layout this means ``w >= 0``).

    Raises:
        ValueError: If ``Q`` contains no quaternions.
    """
    n = Q.shape[0]
    if n == 0:
        raise ValueError("cannot average an empty set of quaternions")

    if weights is None:
        # Build on Q's device/dtype rather than a hard-coded CUDA device,
        # so CPU and float64 inputs work as well.
        weights = torch.full((n,), 1.0 / n, dtype=Q.dtype, device=Q.device)
    else:
        weights = torch.as_tensor(weights, dtype=Q.dtype, device=Q.device)

    # Weighted sum of outer products q_b q_b^T, normalised by the total weight.
    # No sign normalisation of Q is needed: the outer product is sign-invariant.
    A = torch.einsum("b,bi,bk->ik", weights, Q, Q) / weights.sum()

    # eigh returns eigenvalues in ascending order, so the last eigenvector
    # column corresponds to the largest eigenvalue.
    q_avg = torch.linalg.eigh(A)[1][:, -1]

    # Resolve the +/- ambiguity of the eigenvector deterministically.
    if q_avg[0] < 0:
        return -q_avg
    return q_avg
It implements the algorithm described at http://tbirdal.blogspot.com/2019/10/i-allocate-this-post-to-providing.html, which follows Markley, F. Landis, et al., "Averaging Quaternions," Journal of Guidance, Control, and Dynamics 30.4 (2007).