append vector

This commit is contained in:
Michael Yang
2026-02-06 14:47:34 -08:00
parent e19fbe7369
commit 1ec216fe0a
2 changed files with 9 additions and 16 deletions

View File

@@ -12,21 +12,16 @@ package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func doEval(outputs []*Array, async bool) {
vectorData := make([]C.mlx_array, 0, len(outputs))
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output.Valid() {
vectorData = append(vectorData, output.ctx)
C.mlx_vector_array_append_value(vector, output.ctx)
}
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
if async {
C.mlx_async_eval(vector)
} else {

View File

@@ -66,15 +66,13 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
}
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
vectorData := make([]C.mlx_array, len(others)+1)
vectorData[0] = t.ctx
for i := range others {
vectorData[i+1] = others[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
for _, other := range append([]*Array{t}, others...) {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE", t)
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out