diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go index e47f1664d..6ccb1ad6d 100644 --- a/x/create/client/quantize.go +++ b/x/create/client/quantize.go @@ -21,7 +21,7 @@ var quantizeParams = map[string]struct { bits int mode string }{ - "int4": {32, 4, "affine"}, + "int4": {64, 4, "affine"}, "nvfp4": {16, 4, "nvfp4"}, "int8": {64, 8, "affine"}, "mxfp8": {32, 8, "mxfp8"}, diff --git a/x/create/create.go b/x/create/create.go index 46b4393b3..9fb9b1e64 100644 --- a/x/create/create.go +++ b/x/create/create.go @@ -334,12 +334,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string { quantNorm := normalizeQuantType(quantize) // MLX quantization requires last dimension to be divisible by group size - // nvfp4: 16, int4/mxfp8: 32, int8: 64 + // nvfp4: 16, mxfp8: 32, int4/int8: 64 groupSize := int32(32) switch quantNorm { case "nvfp4": groupSize = 16 - case "int8": + case "int4", "int8": groupSize = 64 } if shape[len(shape)-1]%groupSize != 0 {