mirror of
https://github.com/ollama/ollama.git
synced 2026-03-09 07:16:38 -05:00
- Collapse MLX sampling state into a single sample.Sampler struct (options + history). - Replace interface-based sampler chain (TopP, TopK, penalty, etc.) with function-based transforms. - Update request/pipeline wiring to use *sample.Sampler, seed history from prompt tokens, and append generated tokens each step. - Implement top_p, min_p, repeat_penalty, and frequency_penalty
269 lines
6.7 KiB
Go
269 lines
6.7 KiB
Go
//go:build mlx
|
|
|
|
package mlx
|
|
|
|
// #include "generated.h"
|
|
import "C"
|
|
|
|
import (
|
|
"unsafe"
|
|
)
|
|
|
|
func (t *Array) Abs() *Array {
|
|
out := New("ABS")
|
|
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Add(other *Array) *Array {
|
|
out := New("ADD")
|
|
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
|
|
out := New("ADDMM")
|
|
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Argmax(axis int, keepDims bool) *Array {
|
|
out := New("ARGMAX")
|
|
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
|
|
out := New("ARGPARTITION")
|
|
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) ArgsortAxis(axis int) *Array {
|
|
out := New("ARGSORT_AXIS")
|
|
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) AsType(dtype DType) *Array {
|
|
out := New("AS_TYPE")
|
|
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
|
|
cShape := make([]C.int, len(shape))
|
|
for i, s := range shape {
|
|
cShape[i] = C.int(s)
|
|
}
|
|
|
|
cStrides := make([]C.int64_t, len(strides))
|
|
for i, s := range strides {
|
|
cStrides[i] = C.int64_t(s)
|
|
}
|
|
|
|
out := New("AS_STRIDED")
|
|
C.mlx_as_strided(
|
|
&out.ctx, t.ctx,
|
|
unsafe.SliceData(cShape), C.size_t(len(shape)),
|
|
unsafe.SliceData(cStrides), C.size_t(len(strides)),
|
|
C.size_t(offset),
|
|
DefaultStream().ctx,
|
|
)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
|
vector := C.mlx_vector_array_new()
|
|
defer C.mlx_vector_array_free(vector)
|
|
|
|
s := append([]*Array{t}, others...)
|
|
for _, other := range s {
|
|
C.mlx_vector_array_append_value(vector, other.ctx)
|
|
}
|
|
|
|
out := New("CONCATENATE")
|
|
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
|
|
out := New("CUMSUM")
|
|
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Divide(other *Array) *Array {
|
|
out := New("DIVIDE")
|
|
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) ExpandDims(axis int) *Array {
|
|
out := New("EXPAND_DIMS")
|
|
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Flatten(startAxis, endAxis int) *Array {
|
|
out := New("FLATTEN")
|
|
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) FloorDivide(other *Array) *Array {
|
|
out := New("FLOOR_DIVIDE")
|
|
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
|
if lhs == nil {
|
|
lhs = New("")
|
|
}
|
|
if rhs == nil {
|
|
rhs = New("")
|
|
}
|
|
out := New("GATHER_MM")
|
|
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Logsumexp(keepDims bool) *Array {
|
|
out := New("LOGSUMEXP")
|
|
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Less(other *Array) *Array {
|
|
out := New("LESS")
|
|
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Matmul(other *Array) *Array {
|
|
out := New("MATMUL")
|
|
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Multiply(other *Array) *Array {
|
|
out := New("MULTIPLY")
|
|
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Negative() *Array {
|
|
out := New("NEGATIVE")
|
|
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Power(exponent *Array) *Array {
|
|
out := New("POWER")
|
|
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
|
out := New("PUT_ALONG_AXIS")
|
|
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Reshape(axes ...int) *Array {
|
|
cAxes := make([]C.int, len(axes))
|
|
for i := range axes {
|
|
cAxes[i] = C.int(axes[i])
|
|
}
|
|
|
|
out := New("RESHAPE")
|
|
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Sigmoid() *Array {
|
|
out := New("SIGMOID")
|
|
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Sqrt() *Array {
|
|
out := New("SQRT")
|
|
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Squeeze(axis int) *Array {
|
|
out := New("SQUEEZE")
|
|
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
|
|
vectorData := make([]C.mlx_array, len(others)+1)
|
|
vectorData[0] = t.ctx
|
|
for i := range others {
|
|
vectorData[i+1] = others[i].ctx
|
|
}
|
|
|
|
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
|
defer C.mlx_vector_array_free(vector)
|
|
|
|
out := New("STACK_AXIS")
|
|
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Subtract(other *Array) *Array {
|
|
out := New("SUBTRACT")
|
|
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
|
|
out := New("SUM_AXIS")
|
|
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
|
|
out := New("TAKE_AXIS")
|
|
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
|
|
out := New("TAKE_ALONG_AXIS")
|
|
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Tanh() *Array {
|
|
out := New("TANH")
|
|
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func (t *Array) Transpose(axes ...int) *Array {
|
|
cAxes := make([]C.int, len(axes))
|
|
for i, axis := range axes {
|
|
cAxes[i] = C.int(axis)
|
|
}
|
|
|
|
out := New("TRANSPOSE")
|
|
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Zeros(dtype DType, shape ...int) *Array {
|
|
cAxes := make([]C.int, len(shape))
|
|
for i := range shape {
|
|
cAxes[i] = C.int(shape[i])
|
|
}
|
|
|
|
t := New("ZEROS")
|
|
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
|
|
return t
|
|
}
|