Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/training/finetune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ int main(int argc, char ** argv) {
LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
params.cache_type_v = GGML_TYPE_F32;
}
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED) {
LOG_INF("%s: force disabling flash attention (no backward pass implementation)\n", __func__);
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
}

llama_backend_init();
llama_numa_init(params.numa);
Expand Down
23 changes: 22 additions & 1 deletion ggml/src/ggml-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1786,6 +1786,26 @@ ggml_backend_sched_t ggml_backend_sched_new(
return sched;
}

// Ensure the scheduler's hash set — and the per-tensor side arrays keyed by it —
// can hold at least new_graph_size entries. A no-op when the current capacity
// already suffices; otherwise the old storage is discarded and reallocated
// (previous contents are NOT preserved, so callers must repopulate).
static void ggml_backend_sched_grow_hash_set(ggml_backend_sched_t sched, size_t new_graph_size) {
    const size_t required = ggml_hash_size(new_graph_size);
    if (required <= sched->hash_set.size) {
        return; // already big enough
    }

    // drop the old hash set and its parallel metadata arrays
    ggml_hash_set_free(&sched->hash_set);
    free(sched->hv_tensor_backend_ids);
    free(sched->hv_tensor_copies);

    sched->hash_set = ggml_hash_set_new(new_graph_size);

    const size_t n_entries   = sched->hash_set.size;
    const size_t ids_bytes   = n_entries * sizeof(sched->hv_tensor_backend_ids[0]);
    const size_t copy_bytes  = n_entries * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *);

    sched->hv_tensor_backend_ids = (int *) malloc(ids_bytes);
    sched->hv_tensor_copies      = (ggml_tensor **) malloc(copy_bytes);
    GGML_ASSERT(sched->hv_tensor_backend_ids);
    GGML_ASSERT(sched->hv_tensor_copies);

    // fill backend ids with the -1 sentinel (presumably "no backend assigned
    // yet" — confirm against the scheduler's assignment logic) and NULL copies
    memset(sched->hv_tensor_backend_ids, -1, ids_bytes);
    memset(sched->hv_tensor_copies, 0, copy_bytes);
}

void ggml_backend_sched_free(ggml_backend_sched_t sched) {
if (sched == NULL) {
return;
Expand Down Expand Up @@ -1856,9 +1876,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *

bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
GGML_ASSERT(!sched->is_alloc);

ggml_backend_sched_grow_hash_set(sched, graph->n_nodes + graph->n_leafs);

sched->cur_copy = sched->next_copy;
sched->next_copy = (sched->next_copy + 1) % sched->n_copies;

Expand Down
17 changes: 14 additions & 3 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -6817,6 +6817,11 @@ static void ggml_compute_backward(
case GGML_OP_NONE: {
// noop
} break;
case GGML_OP_SET_ROWS: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
}
} break;
case GGML_OP_COUNT:
default: {
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
Expand Down Expand Up @@ -6985,6 +6990,7 @@ void ggml_build_backward_expand(
// gradients in node->src[1] for one reason or another have no effect on output gradients
case GGML_OP_CPY: // gradients in CPY target are irrelevant
case GGML_OP_GET_ROWS: // row indices not differentiable
case GGML_OP_SET_ROWS: // row indices not differentiable
case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
case GGML_OP_ROPE: // positions not differentiable
ignore_src[1] = true;
Expand All @@ -7005,9 +7011,11 @@ void ggml_build_backward_expand(
continue;
}

// inplace operations are currently not supported
// inplace operations: allow ops that have backward implementations
GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_SET_ROWS || node->op == GGML_OP_SCALE || node->op == GGML_OP_SET ||
node->op == GGML_OP_ROPE);

const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(ihash != GGML_HASHSET_FULL);
Expand Down Expand Up @@ -7178,7 +7186,10 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
}

// Duplicate cgraph into a graph allocated from ctx.
// When force_grads is set, the copy will carry gradient bookkeeping even if the
// source has none; in that case up to 3 * n_nodes slots may be needed (forward
// nodes plus grad-related nodes — presumably grads and grad accumulators;
// confirm against ggml_build_backward_expand), so the size is grown accordingly.
// NOTE(review): the rendered diff interleaved the deleted original call
// (ggml_new_graph_custom(ctx, cgraph->size, ...)) with its replacement; this is
// the coherent post-change body with the stale line removed.
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
    const size_t size = force_grads && 3 * cgraph->n_nodes > cgraph->size
        ? 3 * cgraph->n_nodes
        : cgraph->size;
    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, size, cgraph->grads || force_grads);
    ggml_graph_cpy(cgraph, result);
    return result;
}
Expand Down