convert(gptoss): mxfp4 to ggml layout to avoid jit conversion (#12018)

* convert: return bytes written

* ggml flavor mxfp4

* simplify jit conversion

* comment
This commit is contained in:
Michael Yang
2025-08-26 16:41:02 -07:00
committed by GitHub
parent 86834a2797
commit 59412fbb43
6 changed files with 49 additions and 58 deletions

View File

@@ -172,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
blocksDims[i] = int(d)
}
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
bts := b.Bytes()
var tmp [16]byte
for i := 0; i < b.Len(); i += 16 {
for j := range 8 {
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
a, b := bts[i+j], bts[i+j+8]
tmp[2*j+0] = (a & 0x0F) | (b << 4)
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
}
copy(bts[i:i+16], tmp[:])
}
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
var s bytes.Buffer
if _, err := m.scales.WriteTo(&s); err != nil {
@@ -206,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
return 0, err
}
return 0, nil
return int64(len(u8s)), nil
}

View File

@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
const (
tensorKindFP32 uint32 = iota
tensorKindFP16
tensorKindMXFP4 = 4
tensorKindBF16 = 30
tensorKindMXFP4 = 39
)
func (t tensorBase) Kind() uint32 {

View File

@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
switch st.Kind() {
case tensorKindFP32:
return 0, binary.Write(w, binary.LittleEndian, f32s)
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
case tensorKindFP16:
f16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return 0, binary.Write(w, binary.LittleEndian, f16s)
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
case tensorKindBF16:
u8s := bfloat16.EncodeFloat32(f32s)
return 0, binary.Write(w, binary.LittleEndian, u8s)
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
default:
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
}