diff --git a/spec/pg/replication_spec.cr b/spec/pg/replication_spec.cr new file mode 100644 index 0000000..9115bf9 --- /dev/null +++ b/spec/pg/replication_spec.cr @@ -0,0 +1,319 @@ +require "../spec_helper" + +class TestMessageHandler + include PG::Replication::Handler + + getter relations : Hash(Int32, PG::Replication::Relation) = Hash(Int32, PG::Replication::Relation).new + getter types = Hash(Int32, PG::Replication::Type).new + getter truncations = [] of PG::Replication::Truncate + getter transaction : Transaction? + getter data = Hash(Int32, Hash(Bytes, PG::Replication::WALMessage::TupleData)).new { |h, k| + h[k] = {} of Bytes => PG::Replication::WALMessage::TupleData + } + + def received(data : PG::Replication::XLogData, connection : PG::Replication::Connection, &) + yield + connection.last_wal_byte_flushed = data.wal_end + connection.last_wal_byte_applied = data.wal_end + end + + def received(msg : PG::Replication::Begin) + if transaction + raise "We are already running a transaction" + end + + @transaction = Transaction.new( + id: msg.transaction_id, + final_lsn: msg.final_lsn, + timestamp: msg.timestamp, + ) + end + + def received(msg : PG::Replication::Message) + end + + def received(msg : PG::Replication::Commit) + transaction!.events.each do |event| + if relation = relations[event.oid] + # There can be multiple parts of the primary key + case event.type + in .insert?, .update? + if (key = relation.columns.map_with_index { |column, index| index if column.flags.key? }.compact).any? + data[event.oid][Slice.join(key.map { |key_part_index| make_key(event.tuple_data[key_part_index]) })] = event.tuple_data + else + # What do? + end + in .delete? + if (key = relation.columns.map_with_index { |column, index| index if column.flags.key? }.compact).any? + data[event.oid].delete Slice.join(key.map { |key_part_index| make_key(event.tuple_data[key_part_index]) }) + else + # What do? 
+ end + end + end + end + @transaction = nil + end + + def received(msg : PG::Replication::Origin) + end + + def received(msg : PG::Replication::Relation) + relations[msg.oid] = msg + end + + def received(type : PG::Replication::Type) + types[type.oid] = type + end + + def received(msg : PG::Replication::Insert) + transaction!.insert oid: msg.oid, tuple_data: msg.tuple_data + end + + def received(msg : PG::Replication::Update) + transaction!.update oid: msg.oid, tuple_data: msg.new_tuple_data, + old_tuple_data: msg.old_tuple_data, + key_tuple_data: msg.key_tuple_data + end + + def received(msg : PG::Replication::Delete) + transaction!.delete oid: msg.oid, tuple_data: {msg.key_tuple_data, msg.old_tuple_data}.first.not_nil! + end + + def received(msg : PG::Replication::Truncate) + truncations << msg + end + + # TODO: Implement the methods below in order to test higher `proto_version`s + + # Requires proto_version >= 2 + # def received(msg : PG::Replication::StreamStart) + # end + + # def received(msg : PG::Replication::StreamStop) + # end + + # def received(msg : PG::Replication::StreamCommit) + # end + + # def received(msg : PG::Replication::StreamAbort) + # end + + # Requires proto_version >= 3 + # def received(msg : PG::Replication::BeginPrepare) + # end + + # def received(msg : PG::Replication::Prepare) + # end + + # def received(msg : PG::Replication::CommitPrepared) + # end + + # def received(msg : PG::Replication::RollbackPrepared) + # end + + # def received(msg : PG::Replication::StreamPrepare) + # end + + private def transaction! + transaction.not_nil! 
+ end + + private def make_key(value : String | Bytes) : Bytes + value.to_slice + end + + private def make_key(value : PG::Replication::WALMessage::UnchangedTOASTValue) : Bytes + raise ArgumentError.new("Using a TOASTed value in a primary key is unsupported") + end + + private def make_key(value : Nil) : Bytes + Bytes.empty + end + + class Transaction + getter id : Int32 + getter final_lsn : Int64 + getter timestamp : Time + getter events : Array(Event) + + alias TupleData = PG::Replication::WALMessage::TupleData + + def initialize(@id, @final_lsn, @timestamp, @events = [] of Event) + end + + def insert(oid : Int32, tuple_data : PG::Replication::WALMessage::TupleData) + events << Event.new(oid, tuple_data, :insert) + end + + def update(oid : Int32, tuple_data : TupleData, old_tuple_data : TupleData?, key_tuple_data : TupleData?) + events << Event.new(oid, tuple_data, :update, old_tuple_data: old_tuple_data, key_tuple_data: key_tuple_data) + end + + def delete(oid : Int32, tuple_data : PG::Replication::WALMessage::TupleData) + events << Event.new(oid, tuple_data, :delete) + end + + record Event, + oid : Int32, + tuple_data : TupleData, + type : Type, + old_tuple_data : TupleData? = nil, + key_tuple_data : TupleData? = nil do + enum Type + Insert + Update + Delete + end + end + end +end + +if PG_DB.query_one("SHOW wal_level", as: String) == "logical" + describe PG::Replication do + it_consumes_wal "new relations" do |handler, context| + PG_DB.exec "CREATE TABLE #{context.table_name} (id UUID PRIMARY KEY, string TEXT)" + # Apparently the Relation message isn't sent until after data is inserted + # into the table. Without this insert, the `wait_for` call times out. + PG_DB.exec "INSERT INTO #{context.table_name} (id, string) VALUES ($1, $2)", UUID.v7, "yep" + + wait_for { handler.relations.any? 
} + oid, relation = handler.relations.first + + relation.namespace.should eq "public" + relation.name.should eq context.table_name + # The `id` column must be indicated as part of the relation's primary key + relation.columns + .find! { |column| column.name == "id" } + .flags.key?.should eq true + end + + it_consumes_wal "inserts" do |handler, context| + id = UUID.v7 + string = "asdf" + PG_DB.exec "CREATE TABLE #{context.table_name} (id UUID PRIMARY KEY, string TEXT)" + PG_DB.exec "INSERT INTO #{context.table_name} (id, string) VALUES ($1, $2)", id, string + + wait_for { handler.data.any?(&.last.any?) } + + _, records = handler.data.first + pk, tuple = records.first + pk.should eq id.bytes.to_slice + tuple[1].should eq string.to_slice + end + + it_consumes_wal "updates" do |handler, context| + id = UUID.v7 + string = "omg" + PG_DB.exec "CREATE TABLE #{context.table_name} (id UUID PRIMARY KEY, string TEXT)" + PG_DB.exec "INSERT INTO #{context.table_name} (id, string) VALUES ($1, $2)", id, string + PG_DB.exec "UPDATE #{context.table_name} SET string = 'lol' WHERE id = $1", id + + # Wait for at least the insert to propagate + wait_for { handler.data.any?(&.last.any?) } + # Give the update just a little longer to come in + wait_for "record to be updated" do + _, records = handler.data.first + _, tuple = records.first + tuple[1] == "lol".to_slice + end + end + + it_consumes_wal "deletes" do |handler, context| + id = UUID.v7 + string = "omg" + PG_DB.exec "CREATE TABLE #{context.table_name} (id UUID PRIMARY KEY, string TEXT)" + PG_DB.exec "INSERT INTO #{context.table_name} (id, string) VALUES ($1, $2)", id, string + PG_DB.exec "DELETE FROM #{context.table_name} WHERE id = $1", id + + # Wait for at least the insert to propagate + wait_for { handler.data.any?(&.last.any?) } + # Give the delete just a little longer to come in + wait_for "data to be deleted" { handler.data.first.last.none? 
} + end + + it_consumes_wal "new types" do |handler, context| + PG_DB.exec "CREATE TYPE #{context.type_name} AS ENUM ('one', 'two', 'three')" + # The type isn't sent until there's a table that uses it + PG_DB.exec "CREATE TABLE #{context.table_name} (id UUID PRIMARY KEY DEFAULT gen_random_uuid(), thing #{context.type_name})" + # ... and the table isn't sent until there's a record in it + PG_DB.exec "INSERT INTO #{context.table_name} (thing) VALUES ('three')" + + # Wait for the insert to propagate + wait_for { handler.types.any? } + + handler.types.first.last.data_type.should eq context.type_name + end + + it_consumes_wal "schema changes" do |handler, context| + PG_DB.exec "CREATE TABLE #{context.table_name} (id UUID PRIMARY KEY DEFAULT gen_random_uuid())" + PG_DB.exec "INSERT INTO #{context.table_name} (id) VALUES ($1)", UUID.v7 + wait_for { handler.relations.any? } + # Make sure we get that the table has 1 column before proceeding + wait_for { handler.relations.first.last.columns.size == 1 } + + PG_DB.exec "ALTER TABLE #{context.table_name} ADD COLUMN string TEXT" + PG_DB.exec "INSERT INTO #{context.table_name} (id, string) VALUES ($1, $2)", UUID.v7, "my string" + + # Now the table has 2 columns + wait_for { handler.relations.first.last.columns.size == 2 } + end + + it_consumes_wal "truncations" do |handler, context| + PG_DB.exec "CREATE TABLE #{context.table_name} (id UUID PRIMARY KEY DEFAULT gen_random_uuid())" + PG_DB.exec "INSERT INTO #{context.table_name} (id) VALUES ($1)", UUID.v7 + wait_for { handler.relations.any? } + + PG_DB.exec "TRUNCATE #{context.table_name}" + + wait_for { handler.truncations.any? 
} + end + end +else + Log.warn { "Skipping #{__FILE__}, set wal_level=logical in postgresql.conf to enable" } +end + +private def it_consumes_wal(name : String, **options, &block : TestMessageHandler, Context ->) + it "consumes #{name} from the WAL", **options do + context = Context.new + handler = TestMessageHandler.new + PG_DB.exec "CREATE PUBLICATION #{context.publication_name} FOR ALL TABLES" + PG_DB.exec "SELECT pg_create_logical_replication_slot($1, 'pgoutput')", context.slot_name + subscriber = PG.connect_replication DB_URL, + handler: handler, + publication_name: context.publication_name, + slot_name: context.slot_name + + begin + block.call handler, context + ensure + subscriber.try &.close + PG_DB.exec "SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name LIKE 'test_slot_%'" + PG_DB.query_each "SELECT DISTINCT pubname::text FROM pg_publication_tables WHERE schemaname = 'public' and pubname::text LIKE 'test_publication_%'" do |rs| + PG_DB.exec "DROP PUBLICATION #{rs.read(String)}" + end + PG_DB.query_each "SELECT tablename::text FROM pg_tables WHERE schemaname = 'public' and tablename LIKE 'test_table_%'" do |rs| + PG_DB.exec "DROP TABLE IF EXISTS #{rs.read(String)}" + end + PG_DB.query_each "SELECT typname::text FROM pg_type WHERE typname LIKE 'test_type_%'" do |rs| + PG_DB.exec "DROP TYPE IF EXISTS #{rs.read(String)}" + end + end + end +end + +private record Context, + table_name : String = "test_table_#{Random::Secure.hex}", + publication_name : String = "test_publication_#{Random::Secure.hex}", + slot_name : String = "test_slot_#{Random::Secure.hex}", + type_name : String = "test_type_#{Random::Secure.hex}" + +private def wait_for(condition = "the block to return truthy", timeout : Time::Span = 2.seconds, &) + start = Time.monotonic + until yield + if Time.monotonic - start > 2.seconds + raise "Timed out waiting for #{condition}" + end + sleep 5.milliseconds + end +end diff --git a/src/pg.cr b/src/pg.cr index 
112b5d5..812d462 100644 --- a/src/pg.cr +++ b/src/pg.cr @@ -31,6 +31,10 @@ module PG ListenConnection.new(url, channels, blocking, &blk) end + def self.connect_replication(url, *, handler, publication_name, slot_name) + Replication::Connection.new(url, handler, publication_name: publication_name, slot_name: slot_name) + end + class ListenConnection @conn : PG::Connection diff --git a/src/pg/connection.cr b/src/pg/connection.cr index 1f22202..4465433 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -17,7 +17,7 @@ module PG super(options) begin - @connection.connect + @connection.connect(replication: @connection.conninfo.replication) rescue ex raise DB::ConnectionRefused.new(cause: ex) end @@ -95,6 +95,14 @@ module PG end end + protected def listen_replication(publication_name : String, slot_name : String, blocking : Bool = false, &block : Replication::Frame ->) + if blocking + @connection.start_replication_frame_loop(publication_name, slot_name, &block) + else + spawn { @connection.start_replication_frame_loop(publication_name, slot_name, &block) } + end + end + def version vers = connection.server_parameters["server_version"].partition(' ').first.split('.').map(&.to_i) {major: vers[0], minor: vers[1], patch: vers[2]? 
|| 0} diff --git a/src/pg/record.cr b/src/pg/record.cr new file mode 100644 index 0000000..30435f4 --- /dev/null +++ b/src/pg/record.cr @@ -0,0 +1,43 @@ +struct PG::Record(*T) + getter data : T + + def self.read_from(reader : Reader) + new reader.read({{T.map(&.instance)}}) + end + + def initialize(@data) + end + + struct Reader + getter bytes : Bytes + getter size : Int32 + getter connection : Connection + + def initialize(@bytes, @size, @connection) + end + + def read(types : Tuple(*T)) forall T + io = ResultSet::Buffer.new(IO::Memory.new(@bytes), @bytes.size, @connection) + + {% begin %} + { + {% for type in T %} + read({{type}}, io).as({{type.instance}}), + {% end %} + } + {% end %} + end + + private def read(type : T.class, io : IO) : T forall T + oid = io.read_bytes(Int32, IO::ByteFormat::BigEndian) + size = io.read_bytes(Int32, IO::ByteFormat::BigEndian) + Decoders.from_oid(oid).decode(io, size, oid) + end + end +end + +class DB::ResultSet + def read(type : PG::Record.class) + type.read_from read(PG::Record::Reader) + end +end diff --git a/src/pg/replication.cr b/src/pg/replication.cr new file mode 100644 index 0000000..20f5028 --- /dev/null +++ b/src/pg/replication.cr @@ -0,0 +1,204 @@ +require "./replication/frame" +require "./replication/x_log_data" +require "./replication/copy_data" +require "./replication/standby_status_update" + +module PG::Replication + module Handler + # This method must be defined in order to tell the `Connection` how much of + # the WAL has been flushed and applied. + # + # ``` + # connection = PG.listen_replication db_url, + # handler: MyHandler.new, + # publication_name: "my_publication", + # slot_name: "my_replication_slot" + # + # class MyHandler + # include PG::Replication::Handler + # + # def received(data : PG::Replication::XLogData, connection : PG::Replication::Connection, &) + # yield + # connection.last_wal_byte_flushed = data.wal_end + # if data.message.is_a? 
PG::Replication::Commit + # connection.last_wal_byte_applied = data.wal_end + # end + # end + # end + # ``` + abstract def received(data : PG::Replication::XLogData, connection : PG::Replication::Connection, &) + + # Override this method with any of the `WALMessage` subclasses to handle + # receiving replication messages of that type. + # + # ``` + # class MyHandler + # include PG::Replication::Handler + # + # # Using some hypothetical Kafka client for CDC + # def initialize(@kafka : Kafka::Client) + # end + # + # def received(insert : PG::Replication::Insert) + # @kafka.publisher.publish({ + # oid: insert.oid, + # data: insert.tuple_data, + # }.to_msgpack) + # end + # end + # ``` + def received(frame) + end + end + + class Connection + getter handler : Handler + getter publication_name : String + getter slot_name : String + getter last_wal_byte_received = 0i64 + property last_wal_byte_flushed = 0i64 + property last_wal_byte_applied = 0i64 + getter? closed = false + + # :nodoc: + def initialize(uri : URI | String, @handler, *, @publication_name, @slot_name, blocking : Bool = false) + if uri.is_a? String + uri = URI.parse(uri) + else + uri = uri.dup + end + query_params = uri.query_params + query_params["replication"] = "database" + uri.query_params = query_params + @conn = DB.connect(uri).as(PG::Connection) + @conn.listen_replication( + publication_name: publication_name, + slot_name: slot_name, + blocking: blocking, + ) do |frame| + received frame + end + + spawn do + until closed? + sleep 10.seconds + begin + send_keepalive + rescue ex : IO::Error + break if closed? 
+ raise ex + end + end + end + end + + # :nodoc: + def received(frame : CopyBoth) + end + + # :nodoc: + def received(frame : CopyData) + received frame.data + end + + def received(frame : ErrorFrame) + end + + # Handle the `XLogData` message that wraps `WALMessage`s + def received(data : XLogData) + @last_wal_byte_received = data.wal_end + handler.received data, self do + handler.received data.message + end + end + + # :nodoc: + def received(keepalive : KeepAlive) + send_keepalive if keepalive.response_expected? + end + + # This shouldn't ever be received, but it can be represented in memory so we + # need to include it for completeness. + def received(response : KeepAliveResponse) + raise NotImplementedError.new("KeepAliveResponses are intended to be sent, not received") + end + + def close + return if closed? + # We attempt to send off one last keepalive to let the server know where + # we left off. + send_keepalive + @conn.close + ensure + @closed = true + end + + def send_keepalive + write { send_keepalive! } + end + + private def send_keepalive! 
: Nil + CopyData.new( + StandbyStatusUpdate.new( + last_wal_byte_received: last_wal_byte_received, + last_wal_byte_flushed: last_wal_byte_flushed, + last_wal_byte_applied: last_wal_byte_applied, + ) + ).to_io socket + + flush + end + + private def write(&) + write_mutex.synchronize { yield } + end + + private getter write_mutex = Mutex.new + + private def flush + socket.flush + end + + private def socket + @conn.connection.soc + end + end + + # TODO: Implement the types below in order to support higher `proto_version`s + + # # StreamStart requires proto_version >= 2 + # struct StreamStart < WALMessage + # end + + # # StreamStop requires proto_version >= 2 + # struct StreamStop < WALMessage + # end + + # # StreamCommit requires proto_version >= 2 + # struct StreamCommit < WALMessage + # end + + # # StreamAbort requires proto_version >= 2 + # struct StreamAbort < WALMessage + # end + + # # BeginPrepare requires proto_version >=3 + # struct BeginPrepare < WALMessage + # end + + # # Prepare requires proto_version >=3 + # struct Prepare < WALMessage + # end + + # # CommitPrepared requires proto_version >=3 + # struct CommitPrepared < WALMessage + # end + + # # RollbackPrepared requires proto_version >=3 + # struct RollbackPrepared < WALMessage + # end + + # # StreamPrepare requires proto_version >=3 + # struct StreamPrepare < WALMessage + # end +end diff --git a/src/pg/replication/begin.cr b/src/pg/replication/begin.cr new file mode 100644 index 0000000..a04ba65 --- /dev/null +++ b/src/pg/replication/begin.cr @@ -0,0 +1,16 @@ +require "./wal_message" +require "./time_parser" + +module PG::Replication + struct Begin < WALMessage + getter final_lsn : Int64 + getter timestamp : Time + getter transaction_id : Int32 + + def initialize(io : IO) + @final_lsn = read(io, Int64) + @timestamp = TimeParser.call(io) + @transaction_id = read(io, Int32) + end + end +end diff --git a/src/pg/replication/commit.cr b/src/pg/replication/commit.cr new file mode 100644 index 0000000..6c78ba8 
--- /dev/null +++ b/src/pg/replication/commit.cr @@ -0,0 +1,17 @@ +require "./wal_message" + +module PG::Replication + struct Commit < WALMessage + getter flags : Int8 + getter begin_lsn : Int64 + getter end_lsn : Int64 + getter timestamp : Time + + def initialize(io : IO) + @flags = read(io, Int8) + @begin_lsn = read(io, Int64) + @end_lsn = read(io, Int64) + @timestamp = read_time(io) + end + end +end diff --git a/src/pg/replication/copy_both.cr b/src/pg/replication/copy_both.cr new file mode 100644 index 0000000..63fd8e9 --- /dev/null +++ b/src/pg/replication/copy_both.cr @@ -0,0 +1,24 @@ +require "./frame" + +module PG::Replication + struct CopyBoth < Frame + getter format : Format + getter column_formats : Array(Int16) + + def initialize(io : IO) + size = io.read_bytes(Int32, IO::ByteFormat::NetworkEndian) + sized = IO::Sized.new(io, size - 4) + @format = Format.new(sized.read_bytes(Int8, IO::ByteFormat::NetworkEndian)) + column_count = sized.read_bytes(Int16, IO::ByteFormat::NetworkEndian) + @column_formats = Array.new(column_count) do + sized.read_bytes(Int16, IO::ByteFormat::NetworkEndian) + end + sized.close + end + + enum Format : Int8 + Text = 0 + Binary = 1 + end + end +end diff --git a/src/pg/replication/copy_data.cr b/src/pg/replication/copy_data.cr new file mode 100644 index 0000000..4354889 --- /dev/null +++ b/src/pg/replication/copy_data.cr @@ -0,0 +1,34 @@ +module PG::Replication + struct CopyData < Frame + getter data : XLogData | KeepAlive | KeepAliveResponse + + def initialize(io : IO) + size = io.read_bytes(Int32, IO::ByteFormat::NetworkEndian) + case byte = io.read_byte + when Nil + raise IO::EOFError.new("Connection was unexpectedly terminated") + when 'w' + @data = XLogData.new(io) + when 'k' + @data = KeepAlive.new(io) + else + raise Error.new("Unexpected CopyData byte marker: 0x#{byte.to_s(16)} (#{byte.chr.inspect})") + end + end + + def initialize(@data) + end + + def to_io(io : IO) : Nil + buffer = IO::Memory.new + io << 'd' + payload = 
IO::Memory.new.tap { |buf| data.to_io buf }.to_slice + io.write_bytes payload.bytesize + 4, IO::ByteFormat::NetworkEndian + io.write payload + end + end +end + +require "./x_log_data" +require "./keep_alive" +require "./keep_alive_response" diff --git a/src/pg/replication/delete.cr b/src/pg/replication/delete.cr new file mode 100644 index 0000000..03a4ec0 --- /dev/null +++ b/src/pg/replication/delete.cr @@ -0,0 +1,25 @@ +require "./wal_message" + +module PG::Replication + struct Delete < WALMessage + # getter transaction_id : Int32 + getter oid : Int32 + getter key_tuple_data : TupleData? + getter old_tuple_data : TupleData? + + def initialize(io : IO) + # @transaction_id = read(io, Int32) # Requires proto_version >= 2 + @oid = read(io, Int32) + case byte = io.read_byte + when nil + raise IO::EOFError.new("Connection was unexpectedly terminated") + when 'K' + @key_tuple_data = read_tuple_data(io) + when 'O' + @old_tuple_data = read_tuple_data(io) + else + raise Error.new("Expected new TupleData byte marker, got: 0x#{byte.to_s(16)} (#{byte.chr.inspect})") + end + end + end +end diff --git a/src/pg/replication/error_frame.cr b/src/pg/replication/error_frame.cr new file mode 100644 index 0000000..a4449b0 --- /dev/null +++ b/src/pg/replication/error_frame.cr @@ -0,0 +1,36 @@ +module PG::Replication + struct ErrorFrame < Frame + getter severity : Severity? + getter message : String? + getter detail : String? + getter hint : String? 
+ getter misc = [] of {Char, String} + + def initialize(io : IO) + size = read(io, Int32) + pp size: size + loop do + case byte = read(io, UInt8) + when 0 then return + when 'S' then @severity = Severity.parse(read_string(io)) + when 'M' then @message = read_string(io) + when 'D' then @detail = read_string(io) + when 'H' then @hint = read_string(io) + else + @misc << {byte.chr, read_string(io)} + end + end + end + end + + enum Severity + LOG + INFO + DEBUG + NOTICE + WARNING + ERROR + FATAL + PANIC + end +end diff --git a/src/pg/replication/frame.cr b/src/pg/replication/frame.cr new file mode 100644 index 0000000..994f145 --- /dev/null +++ b/src/pg/replication/frame.cr @@ -0,0 +1,26 @@ +require "./read" + +module PG::Replication + abstract struct Frame + include Read + + def self.from_io(io : IO) : self + case byte = io.read_byte + when nil + raise IO::EOFError.new("Connection was unexpectedly terminated") + when 'W' + CopyBoth.new(io) + when 'd' + CopyData.new(io) + when 'E' + ErrorFrame.new(io) + else + raise Error.new("Unexpected byte marker: 0x#{byte.to_s(16)} (#{byte.chr.inspect})") + end + end + end +end + +require "./copy_both" +require "./copy_data" +require "./error_frame" diff --git a/src/pg/replication/insert.cr b/src/pg/replication/insert.cr new file mode 100644 index 0000000..81b6afd --- /dev/null +++ b/src/pg/replication/insert.cr @@ -0,0 +1,18 @@ +require "./wal_message" + +module PG::Replication + struct Insert < WALMessage + # getter transaction_id : Int32 + getter oid : Int32 + getter tuple_data : TupleData + + def initialize(io : IO) + # Only included in WAL protocol v2+ + # @transaction_id = read(io, Int32) + @oid = read(io, Int32) + # This is an 'N' indicating a new tuple + io.read_byte + @tuple_data = read_tuple_data(io) + end + end +end diff --git a/src/pg/replication/keep_alive.cr b/src/pg/replication/keep_alive.cr new file mode 100644 index 0000000..cd3a902 --- /dev/null +++ b/src/pg/replication/keep_alive.cr @@ -0,0 +1,19 @@ +require 
"./time_parser" + +module PG::Replication + struct KeepAlive + getter wal_end : Int64 + getter timestamp : Time + getter? response_expected : Bool + + def initialize(io : IO) + @wal_end = io.read_bytes(Int64, IO::ByteFormat::NetworkEndian) + @timestamp = TimeParser.call(io) + @response_expected = io.read_bytes(Int8, IO::ByteFormat::NetworkEndian) == 1 + end + + def to_io(io : IO) : Nil + raise NotImplementedError.new("KeepAlives are intended to be received from the server, not sent to it") + end + end +end diff --git a/src/pg/replication/keep_alive_response.cr b/src/pg/replication/keep_alive_response.cr new file mode 100644 index 0000000..a40af1f --- /dev/null +++ b/src/pg/replication/keep_alive_response.cr @@ -0,0 +1,4 @@ +module PG::Replication + abstract struct KeepAliveResponse + end +end diff --git a/src/pg/replication/message.cr b/src/pg/replication/message.cr new file mode 100644 index 0000000..2552741 --- /dev/null +++ b/src/pg/replication/message.cr @@ -0,0 +1,25 @@ +require "./wal_message" + +module PG::Replication + struct Message < WALMessage + # Requires proto_version >= 2 + # getter transaction_id : Int32 + getter flags : Flags + getter lsn : Int64 + getter prefix : String + getter content : Bytes + + def initialize(io : IO) + # @transaction_id = read(io, Int32) # Requires proto_version >= 2 + @flags = Flags.new(read(io, Int8)) + @lsn = read(io, Int64) + @prefix = read_string(io) + @content = read_bytes(io) + end + + @[::Flags] + enum Flags : Int8 + Transactional + end + end +end diff --git a/src/pg/replication/origin.cr b/src/pg/replication/origin.cr new file mode 100644 index 0000000..1761ba2 --- /dev/null +++ b/src/pg/replication/origin.cr @@ -0,0 +1,13 @@ +require "./wal_message" + +module PG::Replication + struct Origin < WALMessage + getter lsn : Int64 + getter name : String + + def initialize(io : IO) + @lsn = read(io, Int64) + @name = read_string(io) + end + end +end diff --git a/src/pg/replication/read.cr b/src/pg/replication/read.cr new file 
mode 100644 index 0000000..216f7b4 --- /dev/null +++ b/src/pg/replication/read.cr @@ -0,0 +1,23 @@ +require "./time_parser" + +module PG::Replication + module Read + protected def read(io : IO, int : Int.class) + io.read_bytes int, IO::ByteFormat::NetworkEndian + end + + protected def read_string(io : IO) : String + io.read_line('\0', chomp: true) + end + + protected def read_bytes(io : IO) : Bytes + bytes = Bytes.new(read(io, Int32)) + io.read_fully bytes + bytes + end + + protected def read_time(io : IO) : Time + TimeParser.call(io) + end + end +end diff --git a/src/pg/replication/relation.cr b/src/pg/replication/relation.cr new file mode 100644 index 0000000..d5d7be8 --- /dev/null +++ b/src/pg/replication/relation.cr @@ -0,0 +1,40 @@ +require "./wal_message" + +module PG::Replication + struct Relation < WALMessage + # Requires proto_version >= 2 + # getter transaction_id : Int32 + getter oid : Int32 + getter namespace : String + getter name : String + getter replica_identity : Int8 + getter columns : Array(Column) + + def initialize(io : IO) + # @transaction_id = read(io, Int32) + @oid = read(io, Int32) + @namespace = read_string(io) + @name = read_string(io) + @replica_identity = read(io, Int8) + column_count = read(io, Int16) + @columns = Array.new(column_count) do + Column.new( + flags: Flags.new(read(io, Int8)), + name: read_string(io), + oid: read(io, Int32), + type_modifier: read(io, Int32), + ) + end + end + + record Column, + flags : Flags, + name : String, + oid : Int32, + type_modifier : Int32 + @[::Flags] + enum Flags + Key = 1 + end + end +end diff --git a/src/pg/replication/standby_status_update.cr b/src/pg/replication/standby_status_update.cr new file mode 100644 index 0000000..b8bc6af --- /dev/null +++ b/src/pg/replication/standby_status_update.cr @@ -0,0 +1,34 @@ +require "./keep_alive_response" + +module PG::Replication + struct StandbyStatusUpdate < KeepAliveResponse + getter last_wal_byte_received : Int64 + getter last_wal_byte_flushed : Int64 
+ getter last_wal_byte_applied : Int64 + getter timestamp : Time + getter? response_expected : Bool + + def initialize( + @last_wal_byte_received, + @last_wal_byte_flushed, + @last_wal_byte_applied, + *, + @timestamp = Time.utc, + @response_expected = false, + ) + end + + def to_io(io : IO) : Nil + io << 'r' + write io, last_wal_byte_received + write io, last_wal_byte_flushed + write io, last_wal_byte_applied + write io, (@timestamp - Time.utc(2000, 1, 1)).total_microseconds.to_i64 + write io, response_expected? ? 1u8 : 0u8 + end + + private def write(io, value) + io.write_bytes value, IO::ByteFormat::NetworkEndian + end + end +end diff --git a/src/pg/replication/time_parser.cr b/src/pg/replication/time_parser.cr new file mode 100644 index 0000000..c0f12ef --- /dev/null +++ b/src/pg/replication/time_parser.cr @@ -0,0 +1,13 @@ +module PG::Replication + private module TimeParser + extend self + + def call(io : IO) : Time + call io.read_bytes(Int64, IO::ByteFormat::NetworkEndian) + end + + def call(microseconds : Int64) + Time.utc(2000, 1, 1) + microseconds.microseconds + end + end +end diff --git a/src/pg/replication/truncate.cr b/src/pg/replication/truncate.cr new file mode 100644 index 0000000..cdfb890 --- /dev/null +++ b/src/pg/replication/truncate.cr @@ -0,0 +1,26 @@ +require "./wal_message" + +module PG::Replication + struct Truncate < WALMessage + # transaction_id requires proto_version >= 2 + # getter transaction_id : Int32 + getter options : Options + getter relation_oids : Array(Int32) + + def initialize(io : IO) + # @transaction_id = read(io, Int32) # Requires proto_version >= 2 + relation_count = read(io, Int32) + @options = Options.new(read(io, Int8)) + @relation_oids = Array.new(relation_count) do + read(io, Int32) + end + end + + @[Flags] + enum Options : Int8 + NONE = 0 + CASCADE = 1 + RESTART_IDENTITY = 2 + end + end +end diff --git a/src/pg/replication/type.cr b/src/pg/replication/type.cr new file mode 100644 index 0000000..bc62f2d --- /dev/null +++ 
b/src/pg/replication/type.cr @@ -0,0 +1,16 @@ +require "./wal_message" + +module PG::Replication + struct Type < WALMessage + # getter transaction_id : Int32 + getter oid : Int32 + getter namespace : String + getter data_type : String + + def initialize(io : IO) + @oid = read(io, Int32) + @namespace = read_string(io) + @data_type = read_string(io) + end + end +end diff --git a/src/pg/replication/update.cr b/src/pg/replication/update.cr new file mode 100644 index 0000000..213be7e --- /dev/null +++ b/src/pg/replication/update.cr @@ -0,0 +1,36 @@ +require "./wal_message" + +module PG::Replication + struct Update < WALMessage + getter oid : Int32 + getter key_tuple_data : TupleData? + getter old_tuple_data : TupleData? + getter new_tuple_data : TupleData + + def initialize(io : IO) + @oid = read(io, Int32) + submessage_type = read(io, UInt8) + case submessage_type + when 'K' + @key_tuple_data = read_tuple_data(io) + when 'O' + @old_tuple_data = read_tuple_data(io) + when 'N' + new_tuple_data = read_tuple_data(io) + end + + # If either 'K' or 'O' were specified above, then the next value is our + # new tuple. Otherwise, that was our new tuple, so we just assign it. 
+ if new_tuple_data + @new_tuple_data = new_tuple_data + else + case byte = read(io, UInt8) + when 'N' + @new_tuple_data = read_tuple_data(io) + else + raise Error.new("Expected new TupleData byte marker, got: 0x#{byte.to_s(16)} (#{byte.chr.inspect})") + end + end + end + end +end diff --git a/src/pg/replication/wal_message.cr b/src/pg/replication/wal_message.cr new file mode 100644 index 0000000..bb1495c --- /dev/null +++ b/src/pg/replication/wal_message.cr @@ -0,0 +1,74 @@ +require "./read" + +module PG::Replication + abstract struct WALMessage + include Read + + def self.new(io : IO) : self + {% if @type != PG::Replication::WALMessage %} + {% raise "Must implement #{@type}#initialize(io : IO)" %} + {% end %} + + case byte = io.read_byte + when Nil + raise IO::EOFError.new("Connection was unexpectedly terminated") + when 'B' + Begin.new(io) + when 'C' + Commit.new(io) + when 'R' + Relation.new(io) + when 'I' + Insert.new(io) + when 'U' + Update.new(io) + when 'D' + Delete.new(io) + when 'T' + Truncate.new(io) + when 'Y' + Type.new(io) + when 'M' + Message.new(io) + when 'O' + Origin.new(io) + else + raise Error.new("Unexpected WAL message byte marker: 0x#{byte.to_s(16)} (#{byte.chr.inspect})") + end + end + + protected def read_tuple_data(io) : TupleData + column_count = read(io, Int16) + Array.new(column_count) do + case byte = io.read_byte + when Nil + raise IO::EOFError.new("Connection was unexpectedly terminated") + when 'n' + nil + when 'u' + UnchangedTOASTValue.new + when 't' + read_string(io) + when 'b' + read_bytes(io) + else + raise Error.new("Unexpected TupleData byte marker: 0x#{byte.to_s(16)} (#{byte.chr.inspect})") + end + end + end + + alias TupleData = Array(UnchangedTOASTValue | Bytes | String | Nil) + record UnchangedTOASTValue + end +end + +require "./begin" +require "./commit" +require "./relation" +require "./insert" +require "./update" +require "./delete" +require "./truncate" +require "./type" +require "./message" +require "./origin" diff 
--git a/src/pg/replication/x_log_data.cr b/src/pg/replication/x_log_data.cr
new file mode 100644
index 0000000..57e2e39
--- /dev/null
+++ b/src/pg/replication/x_log_data.cr
@@ -0,0 +1,22 @@
+require "./wal_message"
+require "./time_parser"
+
+module PG::Replication
+  struct XLogData
+    getter wal_start : Int64
+    getter wal_end : Int64
+    getter timestamp : Time
+    getter message : WALMessage
+
+    def initialize(io : IO)
+      @wal_start = io.read_bytes(Int64, IO::ByteFormat::NetworkEndian)
+      @wal_end = io.read_bytes(Int64, IO::ByteFormat::NetworkEndian)
+      @timestamp = TimeParser.call(io)
+      @message = WALMessage.new(io)
+    end
+
+    def to_io(io : IO) : Nil
+      raise NotImplementedError.new("XLogData messages are meant to be sent by the server, not received by the client")
+    end
+  end
+end
diff --git a/src/pq/connection.cr b/src/pq/connection.cr
index c226ddd..845cce5 100644
--- a/src/pq/connection.cr
+++ b/src/pq/connection.cr
@@ -15,6 +15,7 @@ module PQ
   class Connection
     getter soc : UNIXSocket | TCPSocket | OpenSSL::SSL::Socket::Client
     getter server_parameters = Hash(String, String).new
+    getter conninfo : ConnInfo
     property notice_handler = Proc(Notice, Void).new { }
     property notification_handler = Proc(Notification, Void).new { }
     @mutex = Mutex.new
@@ -144,7 +145,7 @@ module PQ
       soc.skip(count)
     end
 
-    def startup(args)
+    def startup(args : Array(String))
       len = args.reduce(0) { |acc, arg| acc + arg.size + 1 }
       write_i32 len + 8 + 1
       write_i32 0x30000
@@ -188,6 +189,24 @@ module PQ
       end
     end
 
+    # Issues START_REPLICATION and yields each incoming replication frame to
+    # the block until the socket is closed.
+    # NOTE(review): `slot_name` is interpolated as a bare identifier and must
+    # come from trusted configuration — it cannot be quote-escaped here.
+    def start_replication_frame_loop(publication_name : String, slot_name : String, &block : PG::Replication::Frame ->)
+      # `publication_names` is a single-quoted string literal in the command,
+      # so embedded quotes must be doubled or they break out of the literal.
+      command = "START_REPLICATION SLOT #{slot_name} LOGICAL 0/0 (proto_version '1', binary 'true', publication_names '#{publication_name.gsub("'", "''")}')"
+      send_query_message command
+      loop do
+        break if soc.closed?
+        block.call PG::Replication::Frame.from_io(soc)
+      rescue e : IO::Error
+        raise e unless soc.closed?
+      rescue e
+        # Keep the frame loop alive on per-frame errors, but leave a trace.
+        Log.error(exception: e) { "Error while processing a replication frame" }
+      end
+    end
+
     private def read_one_frame(frame_type)
       size = read_i32
       slice = read_bytes(size - 4)
@@ -195,38 +209,26 @@
     end
 
     private def handle_async_frames(frame)
-      if frame.is_a?(Frame::ErrorResponse)
-        handle_error frame
-        true
-      elsif frame.is_a?(Frame::NotificationResponse)
-        handle_notification frame
-        true
-      elsif frame.is_a?(Frame::NoticeResponse)
-        handle_notice frame
-        true
-      elsif frame.is_a?(Frame::ParameterStatus)
-        handle_parameter frame
-        true
-      else
-        false
-      end
+      false
     end
 
-    private def handle_error(error_frame : Frame::ErrorResponse)
+    private def handle_async_frames(error_frame : Frame::ErrorResponse)
       expect_frame Frame::ReadyForQuery if @established
       notice_handler.call(error_frame.as_notice)
       raise PQError.new(error_frame.fields)
     end
 
-    private def handle_notice(frame : Frame::NoticeResponse)
-      notice_handler.call(frame.as_notice)
+    private def handle_async_frames(frame : Frame::NotificationResponse)
+      notification_handler.call(frame.as_notification)
+      true
     end
 
-    private def handle_notification(frame : Frame::NotificationResponse)
-      notification_handler.call(frame.as_notification)
+    private def handle_async_frames(frame : Frame::NoticeResponse)
+      notice_handler.call(frame.as_notice)
+      true
     end
 
-    private def handle_parameter(frame : Frame::ParameterStatus)
+    private def handle_async_frames(frame : Frame::ParameterStatus)
       @server_parameters[frame.key] = frame.value
       case frame.key
       when "client_encoding"
@@ -239,18 +241,21 @@
         raise ConnectionError.new(
           "Only on is supported for integer_datetimes, got: #{frame.value.inspect}")
-      else
-        # ignore
       end
+
+      true
     end
 
-    def connect
+    def connect(*, replication : String? = nil)
       startup_args = [
         "user", @conninfo.user,
         "database", @conninfo.database,
         "application_name", @conninfo.application_name,
         "client_encoding", "utf8",
       ]
+      if replication
+        startup_args << "replication" << replication
+      end
 
       startup startup_args
diff --git a/src/pq/conninfo.cr b/src/pq/conninfo.cr
index 6e1b922..a5d77ce 100644
--- a/src/pq/conninfo.cr
+++ b/src/pq/conninfo.cr
@@ -38,10 +38,12 @@ module PQ
     # The application name. Optional (defaults to "crystal").
     getter application_name : String
 
+    getter replication : String?
+
     getter auth_methods : Array(String) = %w[scram-sha-256-plus scram-sha-256 md5]
 
     # Create a new ConnInfo from all parts
-    def initialize(host : String? = nil, database : String? = nil, user : String? = nil, password : String? = nil, port : Int | String? = nil, sslmode : String | Symbol? = nil, application_name : String? = nil)
+    def initialize(host : String? = nil, database : String? = nil, user : String? = nil, password : String? = nil, port : Int | String? = nil, sslmode : String | Symbol? = nil, application_name : String? = nil, @replication = nil)
       @host = default_host host
       db = default_database database
       @database = db.lchop('/')
@@ -77,7 +79,7 @@ module PQ
     def initialize(uri : URI)
       params = URI::Params.parse(uri.query.to_s)
       hostname = uri.hostname.presence || params.fetch("host", "")
-      initialize(hostname, uri.path, uri.user, uri.password, uri.port, :prefer, params.fetch("application_name", nil))
+      initialize(hostname, uri.path, uri.user, uri.password, uri.port, :prefer, params.fetch("application_name", nil), params["replication"]?)
       if q = uri.query
         HTTP::Params.parse(q) do |key, value|
           handle_sslparam(key, value)
@@ -89,10 +91,10 @@
   #
   # Valid keys match Postgres "conninfo" keys and are `"host"`, `"dbname"`,
   # `"user"`, `"password"`, `"port"`, `"sslmode"`, `"sslcert"`, `"sslkey"`,
-  # `"sslrootcert"` and `"application_name"`.
+  # `"sslrootcert"`, `"application_name"`, and `"replication"`.
def initialize(params : Hash) initialize(params["host"]?, params["dbname"]?, params["user"]?, - params["password"]?, params["port"]?, params["sslmode"]?, params["application_name"]?) + params["password"]?, params["port"]?, params["sslmode"]?, params["application_name"]?, params["replication"]?) params.each do |key, value| handle_sslparam(key, value) end