Description:
I am seeing a significant precision drop when applying the QuaRot algorithm on a device limited to float32 computation. The rotations, originally computed in double precision, are cast to float32. This leads to a substantial accuracy drop in the bf16 model: the pass1 score decreases by about 10%, while the pass8 results remain mostly unchanged.
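To illustrate the effect I am worried about, here is a minimal sketch (not the project's actual code, and the matrix size is just an assumption) that compares the round-trip error of a random orthogonal rotation when the rotation and the matmuls are kept in float64 versus cast to float32:

```python
# Minimal sketch: estimate the error introduced by a float32 rotation compared
# with float64, by rotating a stand-in weight matrix and rotating it back.
import torch

torch.manual_seed(0)
dim = 1024  # assumed size for illustration only

# Random orthogonal rotation built in double precision (QR of a Gaussian).
q64, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))
q32 = q64.to(torch.float32)

w64 = torch.randn(dim, dim, dtype=torch.float64)  # stand-in for a weight matrix
w32 = w64.to(torch.float32)

# Rotate and un-rotate in each precision; an exact rotation is lossless.
err64 = (q64.T @ (q64 @ w64) - w64).abs().max()
err32 = (q32.T @ (q32 @ w32) - w32).abs().max()

print(f"max round-trip error, float64 rotation: {err64:.3e}")
print(f"max round-trip error, float32 rotation: {err32:.3e}")
```

On my understanding, the float32 error is several orders of magnitude larger, which is what makes me suspect the cast of the rotations is behind the accuracy drop.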
Expected Behavior:
I expected the bf16 model's accuracy to be largely unaffected, assuming float32 would provide enough numerical stability despite the reduction from double precision.
Observed Behavior:
Pass1 accuracy decreases by about 10%.
Pass8 accuracy remains nearly unchanged.
Steps to Reproduce:
Apply the QuaRot algorithm with float32 calculations to a bf16 model.
Observe the accuracy changes across passes, particularly pass1 and pass8.
Is this behavior typical? Are there any solutions or workarounds?