
Precision Loss in bf16 Model with float32 Rotation Calculation #72

@Niko-zyf


Description:

I am seeing a significant precision drop when running the quarot algorithm on a device limited to float32 computation. The rotations, which are normally computed in double precision, are cast to float32, and this leads to a substantial accuracy loss in the bf16 model: pass1 drops by about 10%, while the pass8 results remain mostly unchanged.
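
For context, here is a minimal sketch (not the QuaRot code itself; the 4096 dimension and the QR-based random rotation are my own assumptions) that measures how much orthogonality a rotation loses when it is built in float32 rather than float64:

```python
import torch

def orthogonality_error(dtype: torch.dtype, dim: int = 4096) -> float:
    # Random orthogonal matrix via QR decomposition, computed in `dtype`.
    q, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=dtype))
    # ||Q^T Q - I||_max: how far the rotation is from being exactly orthogonal.
    eye = torch.eye(dim, dtype=dtype)
    return (q.T @ q - eye).abs().max().item()

print("float64:", orthogonality_error(torch.float64))
print("float32:", orthogonality_error(torch.float32))
```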

Expected Behavior:

I expected the accuracy of the bf16 model to be largely unaffected, assuming float32 would still provide enough numerical stability despite the reduction from double precision.

Observed Behavior:

Pass1 accuracy decreases by about 10%
Pass8 accuracy remains nearly unchanged

Steps to Reproduce:

Apply the quarot algorithm to a bf16 model with the rotations computed in float32.
Observe the precision changes across the different passes, notably pass1 and pass8 (a reproduction sketch follows this list).
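
A hypothetical reproduction sketch, assuming a random orthogonal rotation and a 4096×4096 weight matrix rather than the actual QuaRot rotation and model weights, comparing bf16 weights rotated through a float64 rotation versus its float32 cast:

```python
import torch

dim = 4096
w_bf16 = torch.randn(dim, dim).to(torch.bfloat16)

# Reference rotation built in double precision.
rot64, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=torch.float64))
rot32 = rot64.to(torch.float32)  # what a float32-only device would use

ref = (w_bf16.to(torch.float64) @ rot64).to(torch.bfloat16)
out = (w_bf16.to(torch.float32) @ rot32).to(torch.bfloat16)

rel_err = (out.to(torch.float32) - ref.to(torch.float32)).norm() / ref.to(torch.float32).norm()
print(f"relative error after float32 rotation: {rel_err.item():.3e}")
```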

Is this behavior typical? Are there any solutions or recommended workarounds?
