mirror of
https://github.com/ollama/ollama.git
synced 2026-03-11 20:23:55 -05:00
append vector
This commit is contained in:
@@ -12,21 +12,16 @@ package mlx
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
vectorData := make([]C.mlx_array, 0, len(outputs))
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output.Valid() {
|
||||
vectorData = append(vectorData, output.ctx)
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
if async {
|
||||
C.mlx_async_eval(vector)
|
||||
} else {
|
||||
|
||||
@@ -66,15 +66,13 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
|
||||
}
|
||||
|
||||
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
||||
vectorData := make([]C.mlx_array, len(others)+1)
|
||||
vectorData[0] = t.ctx
|
||||
for i := range others {
|
||||
vectorData[i+1] = others[i].ctx
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, other := range append([]*Array{t}, others...) {
|
||||
C.mlx_vector_array_append_value(vector, other.ctx)
|
||||
}
|
||||
|
||||
out := New("CONCATENATE", t)
|
||||
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user