diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp
index 0a75ac110ca..4c8633e4bdb 100644
--- a/examples/training/finetune.cpp
+++ b/examples/training/finetune.cpp
@@ -39,6 +39,10 @@ int main(int argc, char ** argv) {
         LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
         params.cache_type_v = GGML_TYPE_F32;
     }
+    if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED) {
+        LOG_INF("%s: force disabling flash attention (no backward pass implementation)\n", __func__);
+        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    }
 
     llama_backend_init();
     llama_numa_init(params.numa);
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 1a555bf2a4d..4d016775cc6 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -1786,6 +1786,26 @@ ggml_backend_sched_t ggml_backend_sched_new(
     return sched;
 }
 
+static void ggml_backend_sched_grow_hash_set(ggml_backend_sched_t sched, size_t new_graph_size) {
+    const size_t new_size = ggml_hash_size(new_graph_size);
+    if (new_size <= sched->hash_set.size) {
+        return;
+    }
+
+    ggml_hash_set_free(&sched->hash_set);
+    free(sched->hv_tensor_backend_ids);
+    free(sched->hv_tensor_copies);
+
+    sched->hash_set = ggml_hash_set_new(new_graph_size);
+    sched->hv_tensor_backend_ids = (int *) malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+    sched->hv_tensor_copies = (ggml_tensor **) malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
+    GGML_ASSERT(sched->hv_tensor_backend_ids);
+    GGML_ASSERT(sched->hv_tensor_copies);
+
+    memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+    memset(sched->hv_tensor_copies, 0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
+}
+
 void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     if (sched == NULL) {
         return;
@@ -1856,9 +1876,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
 
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     GGML_ASSERT(sched);
-    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
     GGML_ASSERT(!sched->is_alloc);
 
+    ggml_backend_sched_grow_hash_set(sched, graph->n_nodes + graph->n_leafs);
+
     sched->cur_copy = sched->next_copy;
     sched->next_copy = (sched->next_copy + 1) % sched->n_copies;
 
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 0142498d967..9459237cc7e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6817,6 +6817,11 @@ static void ggml_compute_backward(
         case GGML_OP_NONE: {
             // noop
         } break;
+        case GGML_OP_SET_ROWS: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
+            }
+        } break;
         case GGML_OP_COUNT:
         default: {
             GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
@@ -6985,6 +6990,7 @@ void ggml_build_backward_expand(
             // gradients in node->src[1] for one reason or another have no effect on output gradients
             case GGML_OP_CPY:           // gradients in CPY target are irrelevant
             case GGML_OP_GET_ROWS:      // row indices not differentiable
+            case GGML_OP_SET_ROWS:      // row indices not differentiable
             case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
             case GGML_OP_ROPE:          // positions not differentiable
                 ignore_src[1] = true;
@@ -7005,9 +7011,11 @@ void ggml_build_backward_expand(
             continue;
         }
 
-        // inplace operations are currently not supported
+        // inplace operations: allow ops that have backward implementations
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
-            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
+            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE ||
+            node->op == GGML_OP_SET_ROWS || node->op == GGML_OP_SCALE || node->op == GGML_OP_SET ||
+            node->op == GGML_OP_ROPE);
 
         const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
         GGML_ASSERT(ihash != GGML_HASHSET_FULL);
@@ -7178,7 +7186,10 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
 }
 
 struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
-    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
+    const size_t size = force_grads && 3 * cgraph->n_nodes > cgraph->size
+        ? 3 * cgraph->n_nodes
+        : cgraph->size;
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, size, cgraph->grads || force_grads);
     ggml_graph_cpy(cgraph, result);
     return result;
 }