t = int(input()) sizes = [] def square(size): size.sort() if size[0] == size[1]: return (size[0]*2)**2 if size[0]*2 <= size[1]: return max(size)**2 while size[1] % size[0] != 0: size[1] += 1 return size[1]**2; for i in range(t): size = [int(x) for x in input().split(" ")] print(square(size)) """ def check(): for i in range(1, 101): for j in range(1, 101): if square([i,j]) == 36 and i != 6 and j != 6: print(i,j) check() """