Hi, thanks for the great work.

While looking through the KLCPD code, I was wondering how the mmd2 loss is calculated, especially [this part](https://github.com/OctoberChang/klcpd_code/blob/b3a32ee4ce5a950cefa2319535d2b2f521e7176f/mmd_util.py#L26). Could you give me a reference or an explanation of how this code was derived?

Thanks,
Best regards,
YJHong
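For context, here is my current understanding of the standard unbiased MMD² estimator (Gretton et al., "A Kernel Two-Sample Test", JMLR 2012) with a sum of RBF kernels over several bandwidths. This is an assumption about what `mmd_util.py` computes, not a reading of it. The names `sq_dist`, `mix_rbf`, `mmd2_unbiased` and the bandwidth list `sigmas` are hypothetical. The sketch also shows the expansion ||a − b||² = ||a||² − 2⟨a, b⟩ + ||b||², which is commonly used to get all pairwise distances from a single Gram matrix Z Zᵀ.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    # Expanded squared distance ||a||^2 - 2<a,b> + ||b||^2 (the identity usually
    # applied to a Gram matrix Z Z^T to get every pairwise distance at once).
    dot = sum(x * y for x, y in zip(a, b))
    na = sum(x * x for x in a)
    nb = sum(y * y for y in b)
    # Clamp tiny negative values caused by floating-point cancellation.
    return max(0.0, na - 2.0 * dot + nb)


def mix_rbf(a: Vector, b: Vector, sigmas: List[float]) -> float:
    # Hypothetical mixture kernel: sum of Gaussian RBF kernels, one per bandwidth.
    d2 = sq_dist(a, b)
    return sum(math.exp(-d2 / (2.0 * s * s)) for s in sigmas)


def mmd2_unbiased(X: List[Vector], Y: List[Vector], sigmas: List[float]) -> float:
    # Unbiased MMD^2: within-sample sums skip the diagonal (i == j),
    # the cross term uses all m * n pairs.
    m, n = len(X), len(Y)
    k_xx = sum(mix_rbf(X[i], X[j], sigmas)
               for i in range(m) for j in range(m) if i != j)
    k_yy = sum(mix_rbf(Y[i], Y[j], sigmas)
               for i in range(n) for j in range(n) if i != j)
    k_xy = sum(mix_rbf(x, y, sigmas) for x in X for y in Y)
    return k_xx / (m * (m - 1)) + k_yy / (n * (n - 1)) - 2.0 * k_xy / (m * n)
```

Because the diagonal terms are dropped, this estimator can be slightly negative when both samples come from the same distribution; a biased variant would keep them.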