diff --git a/raft.go b/raft.go index 28c11283..c7973a92 100644 --- a/raft.go +++ b/raft.go @@ -1750,27 +1750,46 @@ func (r *Raft) installSnapshot(rpc RPC, req *InstallSnapshotRequest) { return } - // Separately track the progress of streaming a snapshot over the network - // because this too can take a long time. - countingRPCReader := newCountingReader(rpc.Reader) - - // Spill the remote snapshot to disk - transferMonitor := startSnapshotRestoreMonitor(r.logger, countingRPCReader, req.Size, true) - n, err := io.Copy(sink, countingRPCReader) - transferMonitor.StopAndWait() - if err != nil { - sink.Cancel() - r.logger.Error("failed to copy snapshot", "error", err) - rpcErr = err - return - } + // Spill the remote snapshot to disk. + // Spawn a goroutine to copy the snapshot, so we can handle a + // shutdown signal as well. + diskCopyErrCh := make(chan error, 1) + go func() { + // Separately track the progress of streaming a snapshot over the network + // because this too can take a long time. + countingRPCReader := newCountingReader(rpc.Reader) + transferMonitor := startSnapshotRestoreMonitor(r.logger, countingRPCReader, req.Size, true) + n, err := io.Copy(sink, countingRPCReader) + transferMonitor.StopAndWait() + if err != nil { + r.logger.Error("failed to copy snapshot", "error", err) + diskCopyErrCh <- err + return + } - // Check that we received it all - if n != req.Size { + // Check that we received it all + if n != req.Size { + r.logger.Error("failed to receive whole snapshot", + "received", hclog.Fmt("%d / %d", n, req.Size)) + diskCopyErrCh <- fmt.Errorf("short read") + return + } + + r.logger.Info("copied to local snapshot", "bytes", n) + diskCopyErrCh <- nil + }() + + // Wait for snapshot transfer or shutdown + select { + case err := <-diskCopyErrCh: + if err != nil { + sink.Cancel() + rpcErr = err + return + } + case <-r.shutdownCh: sink.Cancel() - r.logger.Error("failed to receive whole snapshot", - "received", hclog.Fmt("%d / %d", n, req.Size)) - rpcErr = fmt.Errorf("short read") + rpcErr = ErrRaftShutdown return } @@ -1780,7 +1799,6 @@ func (r *Raft) installSnapshot(rpc RPC, req *InstallSnapshotRequest) { rpcErr = err return } - r.logger.Info("copied to local snapshot", "bytes", n) // Restore snapshot future := &restoreFuture{ID: sink.ID()}