From a2f11afd843297ff26b4e68a7fcde524f2523a30 Mon Sep 17 00:00:00 2001 From: Garet Halliday <ghalliday@gfxlabs.io> Date: Wed, 7 Sep 2022 13:55:10 -0500 Subject: [PATCH] pool connections --- cmd/cgat/main.go | 10 +- lib/gat/client.go | 1021 +--------------- lib/gat/conn_pool.go | 13 + lib/gat/gat.go | 5 + lib/gat/gatling.go | 202 ---- lib/gat/{ => gatling/admin}/admin.go | 22 +- lib/gat/{ => gatling/admin}/admin_test.go | 2 +- lib/gat/gatling/client/client.go | 1052 +++++++++++++++++ lib/gat/{ => gatling/client}/client_test.go | 2 +- lib/gat/gatling/conn_pool/conn_pool.go | 69 ++ lib/gat/gatling/gatling.go | 142 +++ lib/gat/{ => gatling/messages}/messages.go | 2 +- .../{ => gatling/messages}/messages_test.go | 2 +- lib/gat/gatling/pool/pool.go | 544 +++++++++ lib/gat/{ => gatling/pool}/pool_test.go | 2 +- .../query_router}/query_router.go | 9 +- .../query_router}/query_router_test.go | 2 +- lib/gat/{ => gatling/server}/server.go | 42 +- lib/gat/{ => gatling/server}/server_test.go | 7 +- lib/gat/{ => gatling/sharding}/sharding.go | 2 +- .../{ => gatling/sharding}/sharding_test.go | 2 +- lib/gat/{ => gatling/stats}/stats.go | 2 +- lib/gat/pool.go | 566 +-------- lib/gat/{ => protocol/pg_error}/error.go | 12 +- 24 files changed, 1909 insertions(+), 1825 deletions(-) create mode 100644 lib/gat/conn_pool.go create mode 100644 lib/gat/gat.go delete mode 100644 lib/gat/gatling.go rename lib/gat/{ => gatling/admin}/admin.go (96%) rename lib/gat/{ => gatling/admin}/admin_test.go (81%) create mode 100644 lib/gat/gatling/client/client.go rename lib/gat/{ => gatling/client}/client_test.go (76%) create mode 100644 lib/gat/gatling/conn_pool/conn_pool.go create mode 100644 lib/gat/gatling/gatling.go rename lib/gat/{ => gatling/messages}/messages.go (99%) rename lib/gat/{ => gatling/messages}/messages_test.go (79%) create mode 100644 lib/gat/gatling/pool/pool.go rename lib/gat/{ => gatling/pool}/pool_test.go (77%) rename lib/gat/{ => gatling/query_router}/query_router.go (97%) rename lib/gat/{ => gatling/query_router}/query_router_test.go (99%) rename lib/gat/{ => gatling/server}/server.go (84%) rename lib/gat/{ => gatling/server}/server_test.go (65%) rename lib/gat/{ => gatling/sharding}/sharding.go (99%) rename lib/gat/{ => gatling/sharding}/sharding_test.go (99%) rename lib/gat/{ => gatling/stats}/stats.go (99%) rename lib/gat/{ => protocol/pg_error}/error.go (99%) diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 3a5407fd..078a1b07 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -3,7 +3,7 @@ package main import ( "context" "gfx.cafe/gfx/pggat/lib/config" - "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/gatling" "git.tuxpa.in/a/zlog/log" ) @@ -15,14 +15,10 @@ func main() { if err != nil { panic(err) } - gatling := gat.NewGatling() - err = gatling.ApplyConfig(conf) - if err != nil { - panic(err) - } + g := gatling.NewGatling(conf) log.Println("listening on port", conf.General.Port) - err = gatling.ListenAndServe(context.Background()) + err = g.ListenAndServe(context.Background()) if err != nil { panic(err) } diff --git a/lib/gat/client.go b/lib/gat/client.go index 98d548f8..a13d0cf9 100644 --- a/lib/gat/client.go +++ b/lib/gat/client.go @@ -1,1025 +1,10 @@ package gat import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "fmt" - "io" - "math/big" - "net" - "reflect" - "gfx.cafe/gfx/pggat/lib/gat/protocol" - "gfx.cafe/gfx/pggat/lib/util/maps" - - "gfx.cafe/gfx/pggat/lib/config" - "git.tuxpa.in/a/zlog" - "git.tuxpa.in/a/zlog/log" - 
"github.com/ethereum/go-ethereum/common/math" ) -type ClientKey [2]int - -type ClientInfo struct { - A int - B int - C string - D uint16 -} - -// / client state, one per client -type Client struct { - conn net.Conn - r *bufio.Reader - wr io.Writer - - buf bytes.Buffer - - addr net.Addr - - cancel_mode bool - txn_mode bool - - pid int32 - secret_key int32 - - parameters map[string]string - stats any // TODO: Reporter - admin bool - - server *Server - - last_addr_id int - last_srv_id int - - pool_name string - username string - - conf *config.Global - - log zlog.Logger -} - -func NewClient( - conf *config.Global, - conn net.Conn, - admin_only bool, -) *Client { - c := &Client{ - conn: conn, - r: bufio.NewReader(conn), - wr: conn, - addr: conn.RemoteAddr(), - conf: conf, - } - c.log = log.With(). - Stringer("clientaddr", c.addr).Logger() - return c -} - -func (c *Client) Accept(ctx context.Context) error { - // read a packet - startup := new(protocol.StartupMessage) - err := startup.Read(c.r) - if err != nil { - return err - } - switch startup.Fields.ProtocolVersionNumber { - case 196608: - case 80877102: - return c.handle_cancel(ctx, startup) - case 80877103: - // ssl stuff now - useSsl := (c.conf.General.TlsCertificate != "") - if !useSsl { - _, err = protocol.WriteByte(c.wr, 'N') - if err != nil { - return err - } - startup = new(protocol.StartupMessage) - err = startup.Read(c.r) - if err != nil { - return err - } - } else { - _, err = protocol.WriteByte(c.wr, 'S') - if err != nil { - return err - } - //TODO: we need to do an ssl handshake here. - var cert tls.Certificate - cert, err = tls.LoadX509KeyPair(c.conf.General.TlsCertificate, c.conf.General.TlsPrivateKey) - if err != nil { - return err - } - cfg := &tls.Config{ - Certificates: []tls.Certificate{cert}, - InsecureSkipVerify: true, - } - c.conn = tls.Server(c.conn, cfg) - c.r = bufio.NewReader(c.conn) - c.wr = c.conn - err = startup.Read(c.r) - if err != nil { - return err - } - } - } - params := make(map[string]string) - for _, v := range startup.Fields.Parameters { - params[v.Name] = v.Value - } - - var ok bool - c.pool_name, ok = params["database"] - if !ok { - return &PostgresError{ - Severity: Fatal, - Code: InvalidAuthorizationSpecification, - Message: "param database required", - } - } - - c.username, ok = params["user"] - if !ok { - return &PostgresError{ - Severity: Fatal, - Code: InvalidAuthorizationSpecification, - Message: "param user required", - } - } - - c.admin = (c.pool_name == "pgcat" || c.pool_name == "pgbouncer") - - if c.conf.General.AdminOnly && !c.admin { - c.log.Debug().Msg("rejected non admin, since admin only mode") - return &PostgresError{ - Severity: Fatal, - Code: InvalidAuthorizationSpecification, - Message: "rejected non admin", - } - } - - pid, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt32)) - if err != nil { - return err - } - c.pid = int32(pid.Int64()) - skey, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt32)) - if err != nil { - return err - } - - c.secret_key = int32(skey.Int64()) - // TODO: Add SASL support. - - // Perform MD5 authentication. 
- pkt, salt, err := CreateMd5Challenge() - if err != nil { - return err - } - _, err = pkt.Write(c.wr) - if err != nil { - return err - } - - var rsp protocol.Packet - rsp, err = protocol.ReadFrontend(c.r) - if err != nil { - return err - } - var passwordResponse []byte - switch r := rsp.(type) { - case *protocol.AuthenticationResponse: - passwordResponse = r.Fields.Data - default: - return &PostgresError{ - Severity: Fatal, - Code: InvalidAuthorizationSpecification, - Message: fmt.Sprintf("wanted AuthenticationResponse packet, got '%+v'", rsp), - } - } - - pool, ok := c.conf.Pools[c.pool_name] - if !ok { - return &PostgresError{ - Severity: Fatal, - Code: InvalidAuthorizationSpecification, - Message: "no such pool", - } - } - _, user, ok := maps.FirstWhere(pool.Users, func(_ string, user config.User) bool { - return user.Name == c.username - }) - if !ok { - return &PostgresError{ - Severity: Fatal, - Code: InvalidPassword, - Message: "user not found", - } - } - - // Authenticate admin user. - if c.admin { - pw_hash := Md5HashPassword(c.conf.General.AdminUsername, c.conf.General.AdminPassword, salt[:]) - if !reflect.DeepEqual(pw_hash, passwordResponse) { - return &PostgresError{ - Severity: Fatal, - Code: InvalidPassword, - Message: "invalid password", - } - } - } else { - pw_hash := Md5HashPassword(c.username, user.Password, salt[:]) - if !reflect.DeepEqual(pw_hash, passwordResponse) { - return &PostgresError{ - Severity: Fatal, - Code: InvalidPassword, - Message: "invalid password", - } - } - } - - shard := pool.Shards["0"] - serv := shard.Servers[0] - c.server, err = DialServer(context.TODO(), fmt.Sprintf("%s:%d", serv.Host(), serv.Port()), &user, shard.Database, nil) - if err != nil { - return err - } - - c.log.Debug().Msg("Password authentication successful") - authOk := new(protocol.Authentication) - authOk.Fields.Code = 0 - _, err = authOk.Write(c.wr) - if err != nil { - return err - } - - // - for _, inf := range c.server.server_info { - _, err = inf.Write(c.wr) - if err != nil { - return err - } - } - backendKeyData := new(protocol.BackendKeyData) - backendKeyData.Fields.ProcessID = c.pid - backendKeyData.Fields.SecretKey = c.secret_key - _, err = backendKeyData.Write(c.wr) - if err != nil { - return err - } - readyForQuery := new(protocol.ReadyForQuery) - readyForQuery.Fields.Status = byte('I') - _, err = readyForQuery.Write(c.wr) - if err != nil { - return err - } - c.log.Debug().Msg("Ready for Query") - open := true - for open { - open, err = c.tick(ctx) - if err != nil { - return err - } - } - return nil -} - -// TODO: we need to keep track of queries so we can handle cancels -func (c *Client) handle_cancel(ctx context.Context, p *protocol.StartupMessage) error { - log.Println("cancel msg", p) - return nil -} - -// reads a packet from stream and handles it -func (c *Client) tick(ctx context.Context) (bool, error) { - rsp, err := protocol.ReadFrontend(c.r) - if err != nil { - return true, err - } - switch cast := rsp.(type) { - case *protocol.Describe: - case *protocol.FunctionCall: - return true, c.handle_function(ctx, cast) - case *protocol.Query: - return true, c.handle_query(ctx, cast) - case *protocol.Terminate: - return false, nil - default: - } - return true, nil -} - -func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { - // TODO extract query and do stuff based on it - _, err := q.Write(c.server.wr) - if err != nil { - return err - } - for { - var rsp protocol.Packet - rsp, err = protocol.ReadBackend(c.server.r) - if err != nil { - return err - 
} - switch r := rsp.(type) { - case *protocol.ReadyForQuery: - if r.Fields.Status == 'I' { - _, err = r.Write(c.wr) - if err != nil { - return err - } - return nil - } - case *protocol.CopyInResponse, *protocol.CopyOutResponse, *protocol.CopyBothResponse: - err = c.handle_copy(ctx, rsp) - if err != nil { - return err - } - continue - } - _, err = rsp.Write(c.wr) - if err != nil { - return err - } - } -} - -func (c *Client) handle_function(ctx context.Context, f *protocol.FunctionCall) error { - _, err := f.Write(c.wr) - if err != nil { - return err - } - for { - var rsp protocol.Packet - rsp, err = protocol.ReadBackend(c.server.r) - if err != nil { - return err - } - _, err = rsp.Write(c.wr) - if err != nil { - return err - } - if r, ok := rsp.(*protocol.ReadyForQuery); ok { - if r.Fields.Status == 'I' { - break - } - } - } - - return nil -} - -func (c *Client) handle_copy(ctx context.Context, p protocol.Packet) error { - _, err := p.Write(c.wr) - if err != nil { - return err - } - switch p.(type) { - case *protocol.CopyInResponse: - outer: - for { - var rsp protocol.Packet - rsp, err = protocol.ReadFrontend(c.r) - if err != nil { - return err - } - // forward packet - _, err = rsp.Write(c.server.wr) - if err != nil { - return err - } - - switch rsp.(type) { - case *protocol.CopyDone, *protocol.CopyFail: - break outer - } - } - return nil - case *protocol.CopyOutResponse: - for { - var rsp protocol.Packet - rsp, err = protocol.ReadBackend(c.server.r) - if err != nil { - return err - } - // forward packet - _, err = rsp.Write(c.wr) - if err != nil { - return err - } - - switch r := rsp.(type) { - case *protocol.CopyDone: - return nil - case *protocol.ErrorResponse: - e := new(PostgresError) - e.Read(r) - return e - } - } - case *protocol.CopyBothResponse: - // TODO fix this filthy hack, instead of going in parallel (like normal), read fields serially - err = c.handle_copy(ctx, new(protocol.CopyInResponse)) - if err != nil { - return err - } - err = c.handle_copy(ctx, new(protocol.CopyOutResponse)) - if err != nil { - return err - } - return nil - default: - panic("unreachable") - } -} - -func todo() { - // - // /// Handle cancel request. - // pub async fn cancel( - // read: S, - // write: T, - // addr: std::net::SocketAddr, - // mut bytes: BytesMut, // The rest of the startup message. - // client_server_map: ClientServerMap, - // shutdown: Receiver<()>, - // ) -> Result<Client<S, T>, Error> { - // let process_id = bytes.get_i32(); - // let secret_key = bytes.get_i32(); - // return Ok(Client { - // read: BufReader::new(read), - // write: write, - // addr, - // buffer: BytesMut::with_capacity(8196), - // cancel_mode: true, - // transaction_mode: false, - // process_id, - // secret_key, - // client_server_map, - // parameters: HashMap::new(), - // stats: get_reporter(), - // admin: false, - // last_address_id: None, - // last_server_id: None, - // pool_name: String::from("undefined"), - // username: String::from("undefined"), - // shutdown, - // connected_to_server: false, - // }); - // } - // - // /// Handle a connected and authenticated client. - // pub async fn handle(&mut self) -> Result<(), Error> { - // // The client wants to cancel a query it has issued previously. - // if self.cancel_mode { - // trace!("Sending CancelRequest"); - // - // let (process_id, secret_key, address, port) = { - // let guard = self.client_server_map.lock(); - // - // match guard.get(&(self.process_id, self.secret_key)) { - // // Drop the mutex as soon as possible. 
- // // We found the server the client is using for its query - // // that it wants to cancel. - // Some((process_id, secret_key, address, port)) => ( - // process_id.clone(), - // secret_key.clone(), - // address.clone(), - // *port, - // ), - // - // // The client doesn't know / got the wrong server, - // // we're closing the connection for security reasons. - // None => return Ok(()), - // } - // }; - // - // // Opens a new separate connection to the server, sends the backend_id - // // and secret_key and then closes it for security reasons. No other interactions - // // take place. - // return Ok(Server::cancel(&address, port, process_id, secret_key).await?); - // } - // - // // The query router determines where the query is going to go, - // // e.g. primary, replica, which shard. - // let mut query_router = QueryRouter::new(); - // - // // Our custom protocol loop. - // // We expect the client to either start a transaction with regular queries - // // or issue commands for our sharding and server selection protocol. - // loop { - // trace!( - // "Client idle, waiting for message, transaction mode: {}", - // self.transaction_mode - // ); - // - // // Read a complete message from the client, which normally would be - // // either a `Q` (query) or `P` (prepare, extended protocol). - // // We can parse it here before grabbing a server from the pool, - // // in case the client is sending some custom protocol messages, e.g. - // // SET SHARDING KEY TO 'bigint'; - // - // let mut message = tokio::select! { - // _ = self.shutdown.recv() => { - // if !self.admin { - // error_response_terminal( - // &mut self.write, - // &format!("terminating connection due to administrator command") - // ).await?; - // return Ok(()) - // } - // - // // Admin clients ignore shutdown. - // else { - // read_message(&mut self.read).await? - // } - // }, - // message_result = read_message(&mut self.read) => message_result? - // }; - // - // // Avoid taking a server if the client just wants to disconnect. - // if message[0] as char == 'X' { - // debug!("Client disconnecting"); - // return Ok(()); - // } - // - // // Handle admin database queries. - // if self.admin { - // debug!("Handling admin command"); - // handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; - // continue; - // } - // - // // Get a pool instance referenced by the most up-to-date - // // pointer. This ensures we always read the latest config - // // when starting a query. - // let pool = match get_pool(self.pool_name.clone(), self.username.clone()) { - // Some(pool) => pool, - // None => { - // error_response( - // &mut self.write, - // &format!( - // "No pool configured for database: {:?}, user: {:?}", - // self.pool_name, self.username - // ), - // ) - // .await?; - // return Err(Error::ClientError); - // } - // }; - // query_router.update_pool_settings(pool.settings.clone()); - // let current_shard = query_router.shard(); - // - // // Handle all custom protocol commands, if any. - // match query_router.try_execute_command(message.clone()) { - // // Normal query, not a custom command. - // None => { - // if query_router.query_parser_enabled() { - // query_router.infer_role(message.clone()); - // } - // } - // - // // SET SHARD TO - // Some((Command::SetShard, _)) => { - // // Selected shard is not configured. - // if query_router.shard() >= pool.shards() { - // // Set the shard back to what it was. 
- // query_router.set_shard(current_shard); - // - // error_response( - // &mut self.write, - // &format!( - // "shard {} is more than configured {}, staying on shard {}", - // query_router.shard(), - // pool.shards(), - // current_shard, - // ), - // ) - // .await?; - // } else { - // custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; - // } - // continue; - // } - // - // // SET PRIMARY READS TO - // Some((Command::SetPrimaryReads, _)) => { - // custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?; - // continue; - // } - // - // // SET SHARDING KEY TO - // Some((Command::SetShardingKey, _)) => { - // custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?; - // continue; - // } - // - // // SET SERVER ROLE TO - // Some((Command::SetServerRole, _)) => { - // custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; - // continue; - // } - // - // // SHOW SERVER ROLE - // Some((Command::ShowServerRole, value)) => { - // show_response(&mut self.write, "server role", &value).await?; - // continue; - // } - // - // // SHOW SHARD - // Some((Command::ShowShard, value)) => { - // show_response(&mut self.write, "shard", &value).await?; - // continue; - // } - // - // // SHOW PRIMARY READS - // Some((Command::ShowPrimaryReads, value)) => { - // show_response(&mut self.write, "primary reads", &value).await?; - // continue; - // } - // }; - // - // debug!("Waiting for connection from pool"); - // - // // Grab a server from the pool. - // let connection = match pool - // .get(query_router.shard(), query_router.role(), self.process_id) - // .await - // { - // Ok(conn) => { - // debug!("Got connection from pool"); - // conn - // } - // Err(err) => { - // // Clients do not expect to get SystemError followed by ReadyForQuery in the middle - // // of extended protocol submission. So we will hold off on sending the actual error - // // message to the client until we get 'S' message - // match message[0] as char { - // 'P' | 'B' | 'E' | 'D' => (), - // _ => { - // error_response( - // &mut self.write, - // "could not get connection from the pool", - // ) - // .await?; - // } - // }; - // - // error!("Could not get connection from pool: {:?}", err); - // - // continue; - // } - // }; - // - // let mut reference = connection.0; - // let address = connection.1; - // let server = &mut *reference; - // - // // Server is assigned to the client in case the client wants to - // // cancel a query later. - // server.claim(self.process_id, self.secret_key); - // self.connected_to_server = true; - // - // // Update statistics. - // if let Some(last_address_id) = self.last_address_id { - // self.stats - // .client_disconnecting(self.process_id, last_address_id); - // } - // self.stats.client_active(self.process_id, address.id); - // - // self.last_address_id = Some(address.id); - // self.last_server_id = Some(server.process_id()); - // - // debug!( - // "Client {:?} talking to server {:?}", - // self.addr, - // server.address() - // ); - // - // // Set application_name if any. - // // TODO: investigate other parameters and set them too. - // if self.parameters.contains_key("application_name") { - // server - // .set_name(&self.parameters["application_name"]) - // .await?; - // } - // - // // Transaction loop. Multiple queries can be issued by the client here. - // // The connection belongs to the client until the transaction is over, - // // or until the client disconnects if we are in session mode. 
- // // - // // If the client is in session mode, no more custom protocol - // // commands will be accepted. - // loop { - // let mut message = if message.len() == 0 { - // trace!("Waiting for message inside transaction or in session mode"); - // - // match read_message(&mut self.read).await { - // Ok(message) => message, - // Err(err) => { - // // Client disconnected inside a transaction. - // // Clean up the server and re-use it. - // // This prevents connection thrashing by bad clients. - // if server.in_transaction() { - // server.query("ROLLBACK").await?; - // server.query("DISCARD ALL").await?; - // server.set_name("pgcat").await?; - // } - // - // return Err(err); - // } - // } - // } else { - // let msg = message.clone(); - // message.clear(); - // msg - // }; - // - // // The message will be forwarded to the server intact. We still would like to - // // parse it below to figure out what to do with it. - // let original = message.clone(); - // - // let code = message.get_u8() as char; - // let _len = message.get_i32() as usize; - // - // trace!("Message: {}", code); - // - // match code { - // // ReadyForQuery - // 'Q' => { - // debug!("Sending query to server"); - // - // self.send_and_receive_loop(code, original, server, &address, &pool) - // .await?; - // - // if !server.in_transaction() { - // // Report transaction executed statistics. - // self.stats.transaction(self.process_id, address.id); - // - // // Release server back to the pool if we are in transaction mode. - // // If we are in session mode, we keep the server until the client disconnects. - // if self.transaction_mode { - // break; - // } - // } - // } - // - // // Terminate - // 'X' => { - // // Client closing. Rollback and clean up - // // connection before releasing into the pool. - // // Pgbouncer closes the connection which leads to - // // connection thrashing when clients misbehave. - // if server.in_transaction() { - // server.query("ROLLBACK").await?; - // server.query("DISCARD ALL").await?; - // server.set_name("pgcat").await?; - // } - // - // self.release(); - // - // return Ok(()); - // } - // - // // Parse - // // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. - // 'P' => { - // self.buffer.put(&original[..]); - // } - // - // // Bind - // // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' - // 'B' => { - // self.buffer.put(&original[..]); - // } - // - // // Describe - // // Command a client can issue to describe a previously prepared named statement. - // 'D' => { - // self.buffer.put(&original[..]); - // } - // - // // Execute - // // Execute a prepared statement prepared in `P` and bound in `B`. - // 'E' => { - // self.buffer.put(&original[..]); - // } - // - // // Sync - // // Frontend (client) is asking for the query result now. - // 'S' => { - // debug!("Sending query to server"); - // - // self.buffer.put(&original[..]); - // - // self.send_and_receive_loop( - // code, - // self.buffer.clone(), - // server, - // &address, - // &pool, - // ) - // .await?; - // - // self.buffer.clear(); - // - // if !server.in_transaction() { - // self.stats.transaction(self.process_id, address.id); - // - // // Release server back to the pool if we are in transaction mode. - // // If we are in session mode, we keep the server until the client disconnects. 
- // if self.transaction_mode { - // break; - // } - // } - // } - // - // // CopyData - // 'd' => { - // // Forward the data to the server, - // // don't buffer it since it can be rather large. - // self.send_server_message(server, original, &address, &pool) - // .await?; - // } - // - // // CopyDone or CopyFail - // // Copy is done, successfully or not. - // 'c' | 'f' => { - // self.send_server_message(server, original, &address, &pool) - // .await?; - // - // let response = self.receive_server_message(server, &address, &pool).await?; - // - // match write_all_half(&mut self.write, response).await { - // Ok(_) => (), - // Err(err) => { - // server.mark_bad(); - // return Err(err); - // } - // }; - // - // if !server.in_transaction() { - // self.stats.transaction(self.process_id, address.id); - // - // // Release server back to the pool if we are in transaction mode. - // // If we are in session mode, we keep the server until the client disconnects. - // if self.transaction_mode { - // break; - // } - // } - // } - // - // // Some unexpected message. We either did not implement the protocol correctly - // // or this is not a Postgres client we're talking to. - // _ => { - // error!("Unexpected code: {}", code); - // } - // } - // } - // - // // The server is no longer bound to us, we can't cancel it's queries anymore. - // debug!("Releasing server back into the pool"); - // self.stats.server_idle(server.process_id(), address.id); - // self.connected_to_server = false; - // self.release(); - // self.stats.client_idle(self.process_id, address.id); - // } - // } - // - // /// Release the server from the client: it can't cancel its queries anymore. - // pub fn release(&self) { - // let mut guard = self.client_server_map.lock(); - // guard.remove(&(self.process_id, self.secret_key)); - // } - // - // async fn send_and_receive_loop( - // &mut self, - // code: char, - // message: BytesMut, - // server: &mut Server, - // address: &Address, - // pool: &ConnectionPool, - // ) -> Result<(), Error> { - // debug!("Sending {} to server", code); - // - // self.send_server_message(server, message, &address, &pool) - // .await?; - // - // // Read all data the server has to offer, which can be multiple messages - // // buffered in 8196 bytes chunks. - // loop { - // let response = self.receive_server_message(server, &address, &pool).await?; - // - // match write_all_half(&mut self.write, response).await { - // Ok(_) => (), - // Err(err) => { - // server.mark_bad(); - // return Err(err); - // } - // }; - // - // if !server.is_data_available() { - // break; - // } - // } - // - // // Report query executed statistics. 
- // self.stats.query(self.process_id, address.id); - // - // Ok(()) - // } - // - // async fn send_server_message( - // &self, - // server: &mut Server, - // message: BytesMut, - // address: &Address, - // pool: &ConnectionPool, - // ) -> Result<(), Error> { - // match server.send(message).await { - // Ok(_) => Ok(()), - // Err(err) => { - // pool.ban(address, self.process_id); - // Err(err) - // } - // } - // } - // - // async fn receive_server_message( - // &mut self, - // server: &mut Server, - // address: &Address, - // pool: &ConnectionPool, - // ) -> Result<BytesMut, Error> { - // if pool.settings.user.statement_timeout > 0 { - // match tokio::time::timeout( - // tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), - // server.recv(), - // ) - // .await - // { - // Ok(result) => match result { - // Ok(message) => Ok(message), - // Err(err) => { - // pool.ban(address, self.process_id); - // error_response_terminal( - // &mut self.write, - // &format!("error receiving data from server: {:?}", err), - // ) - // .await?; - // Err(err) - // } - // }, - // Err(_) => { - // error!( - // "Statement timeout while talking to {:?} with user {}", - // address, pool.settings.user.username - // ); - // server.mark_bad(); - // pool.ban(address, self.process_id); - // error_response_terminal(&mut self.write, "pool statement timeout").await?; - // Err(Error::StatementTimeout) - // } - // } - // } else { - // match server.recv().await { - // Ok(message) => Ok(message), - // Err(err) => { - // pool.ban(address, self.process_id); - // error_response_terminal( - // &mut self.write, - // &format!("error receiving data from server: {:?}", err), - // ) - // .await?; - // Err(err) - // } - // } - // } - // } - //} - // - //impl<S, T> Drop for Client<S, T> { - // fn drop(&mut self) { - // let mut guard = self.client_server_map.lock(); - // guard.remove(&(self.process_id, self.secret_key)); - // - // // Dirty shutdown - // // TODO: refactor, this is not the best way to handle state management. 
- // if let Some(address_id) = self.last_address_id { - // self.stats.client_disconnecting(self.process_id, address_id); - // - // if self.connected_to_server { - // if let Some(process_id) = self.last_server_id { - // self.stats.server_idle(process_id, address_id); - // } - // } - // } - // } - //} - +type Client interface { + Send(pkt protocol.Packet) error + Recv() (protocol.Packet, error) } diff --git a/lib/gat/conn_pool.go b/lib/gat/conn_pool.go new file mode 100644 index 00000000..5babf921 --- /dev/null +++ b/lib/gat/conn_pool.go @@ -0,0 +1,13 @@ +package gat + +import ( + "context" + "gfx.cafe/gfx/pggat/lib/config" + "gfx.cafe/gfx/pggat/lib/gat/protocol" +) + +type ConnectionPool interface { + GetUser() *config.User + GetServerInfo() []*protocol.ParameterStatus + Query(ctx context.Context, query string) (<-chan protocol.Packet, error) +} diff --git a/lib/gat/gat.go b/lib/gat/gat.go new file mode 100644 index 00000000..d820c773 --- /dev/null +++ b/lib/gat/gat.go @@ -0,0 +1,5 @@ +package gat + +type Gat interface { + GetPool(name string) (Pool, error) +} diff --git a/lib/gat/gatling.go b/lib/gat/gatling.go deleted file mode 100644 index 5f8d637e..00000000 --- a/lib/gat/gatling.go +++ /dev/null @@ -1,202 +0,0 @@ -package gat - -import ( - "context" - "fmt" - "net" - "sync" - - "git.tuxpa.in/a/zlog/log" - - "gfx.cafe/gfx/pggat/lib/config" - "gfx.cafe/gfx/pggat/lib/gat/protocol" -) - -type Gatling struct { - c *config.Global - mu sync.RWMutex - - rout *QueryRouter - - csm map[ClientKey]*ClientInfo - clients map[string]*Client - - chConfig chan *config.Global - - servers map[string]*Server - pools map[string]*ConnectionPool -} - -func NewGatling() *Gatling { - g := &Gatling{ - csm: map[ClientKey]*ClientInfo{}, - chConfig: make(chan *config.Global, 1), - servers: map[string]*Server{}, - clients: map[string]*Client{}, - pools: map[string]*ConnectionPool{}, - rout: &QueryRouter{}, - } - go g.watchConfigs() - return g -} - -func (g *Gatling) watchConfigs() { - for { - c := <-g.chConfig - err := g.ensureConfig(c) - if err != nil { - log.Println("failed to parse config", err) - } - } -} - -func (g *Gatling) GetClient(s string) (*Client, error) { - g.mu.RLock() - defer g.mu.RUnlock() - srv, ok := g.clients[s] - if !ok { - return nil, fmt.Errorf("client '%s' not found", s) - } - return srv, nil -} -func (g *Gatling) GetPool(s string) (*ConnectionPool, error) { - g.mu.RLock() - defer g.mu.RUnlock() - srv, ok := g.pools[s] - if !ok { - return nil, fmt.Errorf("pool '%s' not found", s) - } - return srv, nil -} - -func (g *Gatling) GetServer(s string) (*Server, error) { - g.mu.RLock() - defer g.mu.RUnlock() - srv, ok := g.servers[s] - if !ok { - return nil, fmt.Errorf("server '%s' not found", s) - } - return srv, nil -} - -func (g *Gatling) ensureConfig(c *config.Global) error { - g.mu.Lock() - defer g.mu.Unlock() - - if err := g.ensureGeneral(c); err != nil { - return err - } - if err := g.ensureAdmin(c); err != nil { - return err - } - if err := g.ensureServers(c); err != nil { - return err - } - if err := g.ensurePools(c); err != nil { - return err - } - - return nil -} - -// TODO: all other settings -func (g *Gatling) ensureGeneral(c *config.Global) error { - return nil -} - -// TODO: should configure the admin things, metrics, etc -func (g *Gatling) ensureAdmin(c *config.Global) error { - return nil -} - -// TODO: should connect to and load servers from config -func (g *Gatling) ensureServers(c *config.Global) error { - return nil -} - -// TODO: should connect to & load pools from config -func (g 
*Gatling) ensurePools(c *config.Global) error { - return nil -} - -func (g *Gatling) ListenAndServe(ctx context.Context) error { - ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", g.c.General.Host, g.c.General.Port)) - if err != nil { - return err - } - for { - var c net.Conn - c, err = ln.Accept() - if err != nil { - return err - } - go func() { - err := g.handleConnection(ctx, c) - if err != nil { - log.Println("disconnected:", err) - } - }() - } -} - -// TODO: TLS -func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { - cl := NewClient(g.c, c, false) - err := cl.Accept(ctx) - if err != nil { - log.Println(err.Error()) - switch e := err.(type) { - case *PostgresError: - _, err = e.Packet().Write(cl.wr) - return err - default: - pgErr := &PostgresError{ - Severity: Error, - Code: InternalError, - Message: e.Error(), - } - _, err = pgErr.Packet().Write(cl.wr) - return err - } - } - return nil -} - -type QueryRequest struct { - ctx context.Context - raw protocol.Packet - c *Client -} - -func (g *Gatling) handleQuery(ctx context.Context, c *Client, raw protocol.Packet) error { - // 1. analyze query using the query router - role, err := g.rout.InferRole(raw) - if err != nil { - return err - } - pool, err := g.GetPool(g.selectPool(c, role)) - if err != nil { - return err - } - // check config, select a pool - _ = pool - // TODO: we need to add some more information to the connectionpools, like current load, selectors, etc - // perhaps we should just put the server connections in ServerPool and make that responsible for all of that - srv, err := g.GetServer("some_output") - if err != nil { - return err - } - // write the packet or maybe send in a channel to the server - _ = srv - - // send the result back to the client - _ = c - return nil -} - -func (g *Gatling) selectPool(c *Client, role config.ServerRole) string { - g.mu.RLock() - defer g.mu.RUnlock() - // do some filtering and figure out which pool you want to connect this client to, knowing their rold - return "some_pool" -} diff --git a/lib/gat/admin.go b/lib/gat/gatling/admin/admin.go similarity index 96% rename from lib/gat/admin.go rename to lib/gat/gatling/admin/admin.go index b98aa4b3..fb5103b6 100644 --- a/lib/gat/admin.go +++ b/lib/gat/gatling/admin/admin.go @@ -1,4 +1,4 @@ -package gat +package admin import ( "gfx.cafe/gfx/pggat/lib/gat/protocol" @@ -52,14 +52,14 @@ func AdminServerInfo() []*protocol.ParameterStatus { // stream: &mut T, // mut query: BytesMut, // client_server_map: ClientServerMap, -//) -> Result<(), Error> +//) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ // let code = query.get_u8() as char; // // if code != 'Q' { -// return Err(Error::ProtocolSyncError); +// return Err(Err::ProtocolSyncError); // } // // let len = query.get_i32() as usize; @@ -112,7 +112,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { //} // ///// Column-oriented statistics. -//async fn show_lists<T>(stream: &mut T) -> Result<(), Error> +//async fn show_lists<T>(stream: &mut T) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ @@ -185,7 +185,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { //} // ///// Show PgCat version. 
-//async fn show_version<T>(stream: &mut T) -> Result<(), Error> +//async fn show_version<T>(stream: &mut T) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ @@ -203,7 +203,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { //} // ///// Show utilization of connection pools for each shard and replicas. -//async fn show_pools<T>(stream: &mut T) -> Result<(), Error> +//async fn show_pools<T>(stream: &mut T) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ @@ -262,7 +262,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { //} // ///// Show shards and replicas. -//async fn show_databases<T>(stream: &mut T) -> Result<(), Error> +//async fn show_databases<T>(stream: &mut T) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ @@ -330,7 +330,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { // ///// Ignore any SET commands the client sends. ///// This is common initialization done by ORMs. -//async fn ignore_set<T>(stream: &mut T) -> Result<(), Error> +//async fn ignore_set<T>(stream: &mut T) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ @@ -338,7 +338,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { //} // ///// Reload the configuration file without restarting the process. -//async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error> +//async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ @@ -361,7 +361,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { //} // ///// Shows current configuration. -//async fn show_config<T>(stream: &mut T) -> Result<(), Error> +//async fn show_config<T>(stream: &mut T) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ @@ -407,7 +407,7 @@ func AdminServerInfo() []*protocol.ParameterStatus { //} // ///// Show shard and replicas statistics. -//async fn show_stats<T>(stream: &mut T) -> Result<(), Error> +//async fn show_stats<T>(stream: &mut T) -> Result<(), Err> //where // T: tokio::io::AsyncWrite + std::marker::Unpin, //{ diff --git a/lib/gat/admin_test.go b/lib/gat/gatling/admin/admin_test.go similarity index 81% rename from lib/gat/admin_test.go rename to lib/gat/gatling/admin/admin_test.go index 605458c8..96cc0a09 100644 --- a/lib/gat/admin_test.go +++ b/lib/gat/gatling/admin/admin_test.go @@ -1,3 +1,3 @@ -package gat +package admin // TODO: no tests in original package. 
we shoul write our oen diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go new file mode 100644 index 00000000..2621cbe7 --- /dev/null +++ b/lib/gat/gatling/client/client.go @@ -0,0 +1,1052 @@ +package client + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "fmt" + "gfx.cafe/gfx/pggat/lib/config" + "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/gatling/messages" + "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" + "io" + "math/big" + "net" + "reflect" + + "gfx.cafe/gfx/pggat/lib/gat/protocol" + "git.tuxpa.in/a/zlog" + "git.tuxpa.in/a/zlog/log" + "github.com/ethereum/go-ethereum/common/math" +) + +type ClientKey [2]int + +type ClientInfo struct { + A int + B int + C string + D uint16 +} + +// / client state, one per client +type Client struct { + conn net.Conn + r *bufio.Reader + wr io.Writer + + buf bytes.Buffer + + addr net.Addr + + cancel_mode bool + txn_mode bool + + pid int32 + secret_key int32 + + parameters map[string]string + stats any // TODO: Reporter + admin bool + + server gat.ConnectionPool + + last_addr_id int + last_srv_id int + + pool_name string + username string + + gatling gat.Gat + conf *config.Global + + log zlog.Logger +} + +func NewClient( + gatling gat.Gat, + conf *config.Global, + conn net.Conn, + admin_only bool, +) *Client { + c := &Client{ + conn: conn, + r: bufio.NewReader(conn), + wr: conn, + addr: conn.RemoteAddr(), + gatling: gatling, + conf: conf, + } + c.log = log.With(). + Stringer("clientaddr", c.addr).Logger() + return c +} + +func (c *Client) Accept(ctx context.Context) error { + // read a packet + startup := new(protocol.StartupMessage) + err := startup.Read(c.r) + if err != nil { + return err + } + switch startup.Fields.ProtocolVersionNumber { + case 196608: + case 80877102: + return c.handle_cancel(ctx, startup) + case 80877103: + // ssl stuff now + useSsl := (c.conf.General.TlsCertificate != "") + if !useSsl { + _, err = protocol.WriteByte(c.wr, 'N') + if err != nil { + return err + } + startup = new(protocol.StartupMessage) + err = startup.Read(c.r) + if err != nil { + return err + } + } else { + _, err = protocol.WriteByte(c.wr, 'S') + if err != nil { + return err + } + //TODO: we need to do an ssl handshake here. 
+ var cert tls.Certificate + cert, err = tls.LoadX509KeyPair(c.conf.General.TlsCertificate, c.conf.General.TlsPrivateKey) + if err != nil { + return err + } + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + } + c.conn = tls.Server(c.conn, cfg) + c.r = bufio.NewReader(c.conn) + c.wr = c.conn + err = startup.Read(c.r) + if err != nil { + return err + } + } + } + params := make(map[string]string) + for _, v := range startup.Fields.Parameters { + params[v.Name] = v.Value + } + + var ok bool + c.pool_name, ok = params["database"] + if !ok { + return &pg_error.Error{ + Severity: pg_error.Fatal, + Code: pg_error.InvalidAuthorizationSpecification, + Message: "param database required", + } + } + + c.username, ok = params["user"] + if !ok { + return &pg_error.Error{ + Severity: pg_error.Fatal, + Code: pg_error.InvalidAuthorizationSpecification, + Message: "param user required", + } + } + + c.admin = (c.pool_name == "pgcat" || c.pool_name == "pgbouncer") + + if c.conf.General.AdminOnly && !c.admin { + c.log.Debug().Msg("rejected non admin, since admin only mode") + return &pg_error.Error{ + Severity: pg_error.Fatal, + Code: pg_error.InvalidAuthorizationSpecification, + Message: "rejected non admin", + } + } + + pid, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt32)) + if err != nil { + return err + } + c.pid = int32(pid.Int64()) + skey, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt32)) + if err != nil { + return err + } + + c.secret_key = int32(skey.Int64()) + // TODO: Add SASL support. + + // Perform MD5 authentication. + pkt, salt, err := messages.CreateMd5Challenge() + if err != nil { + return err + } + _, err = pkt.Write(c.wr) + if err != nil { + return err + } + + var rsp protocol.Packet + rsp, err = protocol.ReadFrontend(c.r) + if err != nil { + return err + } + var passwordResponse []byte + switch r := rsp.(type) { + case *protocol.AuthenticationResponse: + passwordResponse = r.Fields.Data + default: + return &pg_error.Error{ + Severity: pg_error.Fatal, + Code: pg_error.InvalidAuthorizationSpecification, + Message: fmt.Sprintf("wanted AuthenticationResponse packet, got '%+v'", rsp), + } + } + + var pool gat.Pool + pool, err = c.gatling.GetPool(c.pool_name) + if err != nil { + return err + } + + // get user + var user *config.User + user, err = pool.GetUser(c.username) + if err != nil { + return err + } + + // Authenticate admin user. 
+ if c.admin { + pw_hash := messages.Md5HashPassword(c.conf.General.AdminUsername, c.conf.General.AdminPassword, salt[:]) + if !reflect.DeepEqual(pw_hash, passwordResponse) { + return &pg_error.Error{ + Severity: pg_error.Fatal, + Code: pg_error.InvalidPassword, + Message: "invalid password", + } + } + } else { + pw_hash := messages.Md5HashPassword(c.username, user.Password, salt[:]) + if !reflect.DeepEqual(pw_hash, passwordResponse) { + return &pg_error.Error{ + Severity: pg_error.Fatal, + Code: pg_error.InvalidPassword, + Message: "invalid password", + } + } + } + + c.server, err = pool.WithUser(c.username) + if err != nil { + return err + } + + c.log.Debug().Msg("Password authentication successful") + authOk := new(protocol.Authentication) + authOk.Fields.Code = 0 + _, err = authOk.Write(c.wr) + if err != nil { + return err + } + + // + info := c.server.GetServerInfo() + for _, inf := range info { + _, err = inf.Write(c.wr) + if err != nil { + return err + } + } + backendKeyData := new(protocol.BackendKeyData) + backendKeyData.Fields.ProcessID = c.pid + backendKeyData.Fields.SecretKey = c.secret_key + _, err = backendKeyData.Write(c.wr) + if err != nil { + return err + } + readyForQuery := new(protocol.ReadyForQuery) + readyForQuery.Fields.Status = byte('I') + _, err = readyForQuery.Write(c.wr) + if err != nil { + return err + } + c.log.Debug().Msg("Ready for Query") + open := true + for open { + open, err = c.tick(ctx) + if err != nil { + return err + } + } + return nil +} + +// TODO: we need to keep track of queries so we can handle cancels +func (c *Client) handle_cancel(ctx context.Context, p *protocol.StartupMessage) error { + log.Println("cancel msg", p) + return nil +} + +// reads a packet from stream and handles it +func (c *Client) tick(ctx context.Context) (bool, error) { + rsp, err := c.Recv() + if err != nil { + return true, err + } + switch cast := rsp.(type) { + case *protocol.Query: + return true, c.handle_query(ctx, cast) + case *protocol.Terminate: + return false, nil + default: + } + return true, nil +} + +func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { + rep, err := c.server.Query(ctx, q.Fields.Query) + if err != nil { + return err + } + for { + rsp := <-rep + if rsp == nil { + break + } + err = c.Send(rsp) + if err != nil { + return err + } + } + return nil +} + +/* +func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { + // TODO extract query and do stuff based on it + _, err := q.Write(c.server.wr) + if err != nil { + return err + } + for { + var rsp protocol.Packet + rsp, err = protocol.ReadBackend(c.server.r) + if err != nil { + return err + } + switch r := rsp.(type) { + case *protocol.ReadyForQuery: + if r.Fields.Status == 'I' { + _, err = r.Write(c.wr) + if err != nil { + return err + } + return nil + } + case *protocol.CopyInResponse, *protocol.CopyOutResponse, *protocol.CopyBothResponse: + err = c.handle_copy(ctx, rsp) + if err != nil { + return err + } + continue + } + _, err = rsp.Write(c.wr) + if err != nil { + return err + } + } +} + +func (c *Client) handle_function(ctx context.Context, f *protocol.FunctionCall) error { + _, err := f.Write(c.wr) + if err != nil { + return err + } + for { + var rsp protocol.Packet + rsp, err = protocol.ReadBackend(c.server.r) + if err != nil { + return err + } + _, err = rsp.Write(c.wr) + if err != nil { + return err + } + if r, ok := rsp.(*protocol.ReadyForQuery); ok { + if r.Fields.Status == 'I' { + break + } + } + } + + return nil +} + +func (c *Client) 
handle_copy(ctx context.Context, p protocol.Packet) error { + _, err := p.Write(c.wr) + if err != nil { + return err + } + switch p.(type) { + case *protocol.CopyInResponse: + outer: + for { + var rsp protocol.Packet + rsp, err = protocol.ReadFrontend(c.r) + if err != nil { + return err + } + // forward packet + _, err = rsp.Write(c.server.wr) + if err != nil { + return err + } + + switch rsp.(type) { + case *protocol.CopyDone, *protocol.CopyFail: + break outer + } + } + return nil + case *protocol.CopyOutResponse: + for { + var rsp protocol.Packet + rsp, err = protocol.ReadBackend(c.server.r) + if err != nil { + return err + } + // forward packet + _, err = rsp.Write(c.wr) + if err != nil { + return err + } + + switch r := rsp.(type) { + case *protocol.CopyDone: + return nil + case *protocol.ErrorResponse: + e := new(error2.Error) + e.Read(r) + return e + } + } + case *protocol.CopyBothResponse: + // TODO fix this filthy hack, instead of going in parallel (like normal), read fields serially + err = c.handle_copy(ctx, new(protocol.CopyInResponse)) + if err != nil { + return err + } + err = c.handle_copy(ctx, new(protocol.CopyOutResponse)) + if err != nil { + return err + } + return nil + default: + panic("unreachable") + } +} + +*/ + +func (c *Client) Send(pkt protocol.Packet) error { + _, err := pkt.Write(c.wr) + return err +} + +func (c *Client) Recv() (protocol.Packet, error) { + pkt, err := protocol.ReadFrontend(c.r) + return pkt, err +} + +var _ gat.Client = (*Client)(nil) + +func todo() { + // + // /// Handle cancel request. + // pub async fn cancel( + // read: S, + // write: T, + // addr: std::net::SocketAddr, + // mut bytes: BytesMut, // The rest of the startup message. + // client_server_map: ClientServerMap, + // shutdown: Receiver<()>, + // ) -> Result<Client<S, T>, Err> { + // let process_id = bytes.get_i32(); + // let secret_key = bytes.get_i32(); + // return Ok(Client { + // read: BufReader::new(read), + // write: write, + // addr, + // buffer: BytesMut::with_capacity(8196), + // cancel_mode: true, + // transaction_mode: false, + // process_id, + // secret_key, + // client_server_map, + // parameters: HashMap::new(), + // stats: get_reporter(), + // admin: false, + // last_address_id: None, + // last_server_id: None, + // pool_name: String::from("undefined"), + // username: String::from("undefined"), + // shutdown, + // connected_to_server: false, + // }); + // } + // + // /// Handle a connected and authenticated client. + // pub async fn handle(&mut self) -> Result<(), Err> { + // // The client wants to cancel a query it has issued previously. + // if self.cancel_mode { + // trace!("Sending CancelRequest"); + // + // let (process_id, secret_key, address, port) = { + // let guard = self.client_server_map.lock(); + // + // match guard.get(&(self.process_id, self.secret_key)) { + // // Drop the mutex as soon as possible. + // // We found the server the client is using for its query + // // that it wants to cancel. + // Some((process_id, secret_key, address, port)) => ( + // process_id.clone(), + // secret_key.clone(), + // address.clone(), + // *port, + // ), + // + // // The client doesn't know / got the wrong server, + // // we're closing the connection for security reasons. + // None => return Ok(()), + // } + // }; + // + // // Opens a new separate connection to the server, sends the backend_id + // // and secret_key and then closes it for security reasons. No other interactions + // // take place. 
+ // return Ok(Server::cancel(&address, port, process_id, secret_key).await?); + // } + // + // // The query router determines where the query is going to go, + // // e.g. primary, replica, which shard. + // let mut query_router = QueryRouter::new(); + // + // // Our custom protocol loop. + // // We expect the client to either start a transaction with regular queries + // // or issue commands for our sharding and server selection protocol. + // loop { + // trace!( + // "Client idle, waiting for message, transaction mode: {}", + // self.transaction_mode + // ); + // + // // Read a complete message from the client, which normally would be + // // either a `Q` (query) or `P` (prepare, extended protocol). + // // We can parse it here before grabbing a server from the pool, + // // in case the client is sending some custom protocol messages, e.g. + // // SET SHARDING KEY TO 'bigint'; + // + // let mut message = tokio::select! { + // _ = self.shutdown.recv() => { + // if !self.admin { + // error_response_terminal( + // &mut self.write, + // &format!("terminating connection due to administrator command") + // ).await?; + // return Ok(()) + // } + // + // // Admin clients ignore shutdown. + // else { + // read_message(&mut self.read).await? + // } + // }, + // message_result = read_message(&mut self.read) => message_result? + // }; + // + // // Avoid taking a server if the client just wants to disconnect. + // if message[0] as char == 'X' { + // debug!("Client disconnecting"); + // return Ok(()); + // } + // + // // Handle admin database queries. + // if self.admin { + // debug!("Handling admin command"); + // handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; + // continue; + // } + // + // // Get a pool instance referenced by the most up-to-date + // // pointer. This ensures we always read the latest config + // // when starting a query. + // let pool = match get_pool(self.pool_name.clone(), self.username.clone()) { + // Some(pool) => pool, + // None => { + // error_response( + // &mut self.write, + // &format!( + // "No pool configured for database: {:?}, user: {:?}", + // self.pool_name, self.username + // ), + // ) + // .await?; + // return Err(Err::ClientError); + // } + // }; + // query_router.update_pool_settings(pool.settings.clone()); + // let current_shard = query_router.shard(); + // + // // Handle all custom protocol commands, if any. + // match query_router.try_execute_command(message.clone()) { + // // Normal query, not a custom command. + // None => { + // if query_router.query_parser_enabled() { + // query_router.infer_role(message.clone()); + // } + // } + // + // // SET SHARD TO + // Some((Command::SetShard, _)) => { + // // Selected shard is not configured. + // if query_router.shard() >= pool.shards() { + // // Set the shard back to what it was. 
+ // query_router.set_shard(current_shard); + // + // error_response( + // &mut self.write, + // &format!( + // "shard {} is more than configured {}, staying on shard {}", + // query_router.shard(), + // pool.shards(), + // current_shard, + // ), + // ) + // .await?; + // } else { + // custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; + // } + // continue; + // } + // + // // SET PRIMARY READS TO + // Some((Command::SetPrimaryReads, _)) => { + // custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?; + // continue; + // } + // + // // SET SHARDING KEY TO + // Some((Command::SetShardingKey, _)) => { + // custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?; + // continue; + // } + // + // // SET SERVER ROLE TO + // Some((Command::SetServerRole, _)) => { + // custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; + // continue; + // } + // + // // SHOW SERVER ROLE + // Some((Command::ShowServerRole, value)) => { + // show_response(&mut self.write, "server role", &value).await?; + // continue; + // } + // + // // SHOW SHARD + // Some((Command::ShowShard, value)) => { + // show_response(&mut self.write, "shard", &value).await?; + // continue; + // } + // + // // SHOW PRIMARY READS + // Some((Command::ShowPrimaryReads, value)) => { + // show_response(&mut self.write, "primary reads", &value).await?; + // continue; + // } + // }; + // + // debug!("Waiting for connection from pool"); + // + // // Grab a server from the pool. + // let connection = match pool + // .get(query_router.shard(), query_router.role(), self.process_id) + // .await + // { + // Ok(conn) => { + // debug!("Got connection from pool"); + // conn + // } + // Err(err) => { + // // Clients do not expect to get SystemError followed by ReadyForQuery in the middle + // // of extended protocol submission. So we will hold off on sending the actual error + // // message to the client until we get 'S' message + // match message[0] as char { + // 'P' | 'B' | 'E' | 'D' => (), + // _ => { + // error_response( + // &mut self.write, + // "could not get connection from the pool", + // ) + // .await?; + // } + // }; + // + // error!("Could not get connection from pool: {:?}", err); + // + // continue; + // } + // }; + // + // let mut reference = connection.0; + // let address = connection.1; + // let server = &mut *reference; + // + // // Server is assigned to the client in case the client wants to + // // cancel a query later. + // server.claim(self.process_id, self.secret_key); + // self.connected_to_server = true; + // + // // Update statistics. + // if let Some(last_address_id) = self.last_address_id { + // self.stats + // .client_disconnecting(self.process_id, last_address_id); + // } + // self.stats.client_active(self.process_id, address.id); + // + // self.last_address_id = Some(address.id); + // self.last_server_id = Some(server.process_id()); + // + // debug!( + // "Client {:?} talking to server {:?}", + // self.addr, + // server.address() + // ); + // + // // Set application_name if any. + // // TODO: investigate other parameters and set them too. + // if self.parameters.contains_key("application_name") { + // server + // .set_name(&self.parameters["application_name"]) + // .await?; + // } + // + // // Transaction loop. Multiple queries can be issued by the client here. + // // The connection belongs to the client until the transaction is over, + // // or until the client disconnects if we are in session mode. 
+ // // + // // If the client is in session mode, no more custom protocol + // // commands will be accepted. + // loop { + // let mut message = if message.len() == 0 { + // trace!("Waiting for message inside transaction or in session mode"); + // + // match read_message(&mut self.read).await { + // Ok(message) => message, + // Err(err) => { + // // Client disconnected inside a transaction. + // // Clean up the server and re-use it. + // // This prevents connection thrashing by bad clients. + // if server.in_transaction() { + // server.query("ROLLBACK").await?; + // server.query("DISCARD ALL").await?; + // server.set_name("pgcat").await?; + // } + // + // return Err(err); + // } + // } + // } else { + // let msg = message.clone(); + // message.clear(); + // msg + // }; + // + // // The message will be forwarded to the server intact. We still would like to + // // parse it below to figure out what to do with it. + // let original = message.clone(); + // + // let code = message.get_u8() as char; + // let _len = message.get_i32() as usize; + // + // trace!("Message: {}", code); + // + // match code { + // // ReadyForQuery + // 'Q' => { + // debug!("Sending query to server"); + // + // self.send_and_receive_loop(code, original, server, &address, &pool) + // .await?; + // + // if !server.in_transaction() { + // // Report transaction executed statistics. + // self.stats.transaction(self.process_id, address.id); + // + // // Release server back to the pool if we are in transaction mode. + // // If we are in session mode, we keep the server until the client disconnects. + // if self.transaction_mode { + // break; + // } + // } + // } + // + // // Terminate + // 'X' => { + // // Client closing. Rollback and clean up + // // connection before releasing into the pool. + // // Pgbouncer closes the connection which leads to + // // connection thrashing when clients misbehave. + // if server.in_transaction() { + // server.query("ROLLBACK").await?; + // server.query("DISCARD ALL").await?; + // server.set_name("pgcat").await?; + // } + // + // self.release(); + // + // return Ok(()); + // } + // + // // Parse + // // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. + // 'P' => { + // self.buffer.put(&original[..]); + // } + // + // // Bind + // // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' + // 'B' => { + // self.buffer.put(&original[..]); + // } + // + // // Describe + // // Command a client can issue to describe a previously prepared named statement. + // 'D' => { + // self.buffer.put(&original[..]); + // } + // + // // Execute + // // Execute a prepared statement prepared in `P` and bound in `B`. + // 'E' => { + // self.buffer.put(&original[..]); + // } + // + // // Sync + // // Frontend (client) is asking for the query result now. + // 'S' => { + // debug!("Sending query to server"); + // + // self.buffer.put(&original[..]); + // + // self.send_and_receive_loop( + // code, + // self.buffer.clone(), + // server, + // &address, + // &pool, + // ) + // .await?; + // + // self.buffer.clear(); + // + // if !server.in_transaction() { + // self.stats.transaction(self.process_id, address.id); + // + // // Release server back to the pool if we are in transaction mode. + // // If we are in session mode, we keep the server until the client disconnects. 
+ // if self.transaction_mode {
+ // break;
+ // }
+ // }
+ // }
+ //
+ // // CopyData
+ // 'd' => {
+ // // Forward the data to the server,
+ // // don't buffer it since it can be rather large.
+ // self.send_server_message(server, original, &address, &pool)
+ // .await?;
+ // }
+ //
+ // // CopyDone or CopyFail
+ // // Copy is done, successfully or not.
+ // 'c' | 'f' => {
+ // self.send_server_message(server, original, &address, &pool)
+ // .await?;
+ //
+ // let response = self.receive_server_message(server, &address, &pool).await?;
+ //
+ // match write_all_half(&mut self.write, response).await {
+ // Ok(_) => (),
+ // Err(err) => {
+ // server.mark_bad();
+ // return Err(err);
+ // }
+ // };
+ //
+ // if !server.in_transaction() {
+ // self.stats.transaction(self.process_id, address.id);
+ //
+ // // Release server back to the pool if we are in transaction mode.
+ // // If we are in session mode, we keep the server until the client disconnects.
+ // if self.transaction_mode {
+ // break;
+ // }
+ // }
+ // }
+ //
+ // // Some unexpected message. We either did not implement the protocol correctly
+ // // or this is not a Postgres client we're talking to.
+ // _ => {
+ // error!("Unexpected code: {}", code);
+ // }
+ // }
+ // }
+ //
+ // // The server is no longer bound to us, we can't cancel its queries anymore.
+ // debug!("Releasing server back into the pool");
+ // self.stats.server_idle(server.process_id(), address.id);
+ // self.connected_to_server = false;
+ // self.release();
+ // self.stats.client_idle(self.process_id, address.id);
+ // }
+ // }
+ //
+ // /// Release the server from the client: it can't cancel its queries anymore.
+ // pub fn release(&self) {
+ // let mut guard = self.client_server_map.lock();
+ // guard.remove(&(self.process_id, self.secret_key));
+ // }
+ //
+ // async fn send_and_receive_loop(
+ // &mut self,
+ // code: char,
+ // message: BytesMut,
+ // server: &mut Server,
+ // address: &Address,
+ // pool: &ConnectionPool,
+ // ) -> Result<(), Err> {
+ // debug!("Sending {} to server", code);
+ //
+ // self.send_server_message(server, message, &address, &pool)
+ // .await?;
+ //
+ // // Read all data the server has to offer, which can be multiple messages
+ // // buffered in 8196-byte chunks.
+ // loop {
+ // let response = self.receive_server_message(server, &address, &pool).await?;
+ //
+ // match write_all_half(&mut self.write, response).await {
+ // Ok(_) => (),
+ // Err(err) => {
+ // server.mark_bad();
+ // return Err(err);
+ // }
+ // };
+ //
+ // if !server.is_data_available() {
+ // break;
+ // }
+ // }
+ //
+ // // Report query executed statistics.
+ // self.stats.query(self.process_id, address.id); + // + // Ok(()) + // } + // + // async fn send_server_message( + // &self, + // server: &mut Server, + // message: BytesMut, + // address: &Address, + // pool: &ConnectionPool, + // ) -> Result<(), Err> { + // match server.send(message).await { + // Ok(_) => Ok(()), + // Err(err) => { + // pool.ban(address, self.process_id); + // Err(err) + // } + // } + // } + // + // async fn receive_server_message( + // &mut self, + // server: &mut Server, + // address: &Address, + // pool: &ConnectionPool, + // ) -> Result<BytesMut, Err> { + // if pool.settings.user.statement_timeout > 0 { + // match tokio::time::timeout( + // tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), + // server.recv(), + // ) + // .await + // { + // Ok(result) => match result { + // Ok(message) => Ok(message), + // Err(err) => { + // pool.ban(address, self.process_id); + // error_response_terminal( + // &mut self.write, + // &format!("error receiving data from server: {:?}", err), + // ) + // .await?; + // Err(err) + // } + // }, + // Err(_) => { + // error!( + // "Statement timeout while talking to {:?} with user {}", + // address, pool.settings.user.username + // ); + // server.mark_bad(); + // pool.ban(address, self.process_id); + // error_response_terminal(&mut self.write, "pool statement timeout").await?; + // Err(Err::StatementTimeout) + // } + // } + // } else { + // match server.recv().await { + // Ok(message) => Ok(message), + // Err(err) => { + // pool.ban(address, self.process_id); + // error_response_terminal( + // &mut self.write, + // &format!("error receiving data from server: {:?}", err), + // ) + // .await?; + // Err(err) + // } + // } + // } + // } + //} + // + //impl<S, T> Drop for Client<S, T> { + // fn drop(&mut self) { + // let mut guard = self.client_server_map.lock(); + // guard.remove(&(self.process_id, self.secret_key)); + // + // // Dirty shutdown + // // TODO: refactor, this is not the best way to handle state management. 
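+ // // In the Go port this Drop-style cleanup would more naturally be a
+ // // deferred call where the client is accepted; a sketch with assumed
+ // // helper names (removeClient is illustrative, not existing API):
+ // //
+ // //   defer func() {
+ // //       g.removeClient(c.pid, c.secret_key) // drop the cancel-key mapping
+ // //       // report client_disconnecting / server_idle stats here
+ // //   }()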
+ // if let Some(address_id) = self.last_address_id {
+ // self.stats.client_disconnecting(self.process_id, address_id);
+ //
+ // if self.connected_to_server {
+ // if let Some(process_id) = self.last_server_id {
+ // self.stats.server_idle(process_id, address_id);
+ // }
+ // }
+ // }
+ // }
+ //}
+
+}
diff --git a/lib/gat/client_test.go b/lib/gat/gatling/client/client_test.go
similarity index 76%
rename from lib/gat/client_test.go
rename to lib/gat/gatling/client/client_test.go
index b28f2da2..e535db4e 100644
--- a/lib/gat/client_test.go
+++ b/lib/gat/gatling/client/client_test.go
@@ -1,3 +1,3 @@
-package gat
+package client

 // TODO: write client tests, original had none
diff --git a/lib/gat/gatling/conn_pool/conn_pool.go b/lib/gat/gatling/conn_pool/conn_pool.go
new file mode 100644
index 00000000..b6a9c437
--- /dev/null
+++ b/lib/gat/gatling/conn_pool/conn_pool.go
@@ -0,0 +1,69 @@
+package conn_pool
+
+import (
+ "context"
+ "fmt"
+ "gfx.cafe/gfx/pggat/lib/config"
+ "gfx.cafe/gfx/pggat/lib/gat"
+ "gfx.cafe/gfx/pggat/lib/gat/gatling/server"
+ "gfx.cafe/gfx/pggat/lib/gat/protocol"
+ "log"
+)
+
+type ConnectionPool struct {
+ c *config.Pool
+ user *config.User
+ pool gat.Pool
+ servers []*server.Server
+}
+
+func NewConnectionPool(pool gat.Pool, conf *config.Pool, user *config.User) *ConnectionPool {
+ p := &ConnectionPool{
+ user: user,
+ pool: pool,
+ }
+ p.EnsureConfig(conf)
+ return p
+}
+
+func (c *ConnectionPool) EnsureConfig(conf *config.Pool) {
+ c.c = conf
+ if len(c.servers) == 0 {
+ // connect to a server
+ shard := c.c.Shards["0"]
+ srv := shard.Servers[0] // TODO choose a better way
+ s, err := server.Dial(context.Background(), fmt.Sprintf("%s:%d", srv.Host(), srv.Port()), c.user, shard.Database, nil)
+ if err != nil {
+ log.Println("error connecting to server", err)
+ // don't pool a nil server after a failed dial
+ return
+ }
+ c.servers = append(c.servers, s)
+ }
+}
+
+func (c *ConnectionPool) GetUser() *config.User {
+ return c.user
+}
+
+func (c *ConnectionPool) GetServerInfo() []*protocol.ParameterStatus {
+ if len(c.servers) > 0 {
+ return c.servers[0].GetServerInfo()
+ }
+ return nil
+}
+
+func (c *ConnectionPool) Query(ctx context.Context, query string) (<-chan protocol.Packet, error) {
+ rep := make(chan protocol.Packet)
+
+ // TODO ideally, this would look at loads, capabilities, etc and choose the server accordingly
+ go func() {
+ err := c.servers[0].Query(query, rep)
+ if err != nil {
+ log.Println(err)
+ }
+ close(rep)
+ }()
+
+ return rep, nil
+}
+
+var _ gat.ConnectionPool = (*ConnectionPool)(nil)
diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go
new file mode 100644
index 00000000..512cfca4
--- /dev/null
+++ b/lib/gat/gatling/gatling.go
@@ -0,0 +1,142 @@
+package gatling
+
+import (
+ "context"
+ "fmt"
+ "gfx.cafe/gfx/pggat/lib/gat"
+ "gfx.cafe/gfx/pggat/lib/gat/gatling/client"
+ "gfx.cafe/gfx/pggat/lib/gat/gatling/pool"
+ "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error"
+ "net"
+ "sync"
+
+ "git.tuxpa.in/a/zlog/log"
+
+ "gfx.cafe/gfx/pggat/lib/config"
+)
+
+type Gatling struct {
+ c *config.Global
+ mu sync.RWMutex
+
+ chConfig chan *config.Global
+
+ pools map[string]*pool.Pool
+}
+
+func NewGatling(conf *config.Global) *Gatling {
+ g := &Gatling{
+ chConfig: make(chan *config.Global, 1),
+ pools: map[string]*pool.Pool{},
+ }
+ err := g.ensureConfig(conf)
+ if err != nil {
+ log.Println("failed to parse config", err)
+ }
+ go g.watchConfigs()
+ return g
+}
+
+func (g *Gatling) watchConfigs() {
+ for {
+ c := <-g.chConfig
+ err := g.ensureConfig(c)
+ if err != nil {
+ log.Println("failed to parse 
config", err) + } + } +} + +func (g *Gatling) GetPool(name string) (gat.Pool, error) { + g.mu.RLock() + defer g.mu.RUnlock() + srv, ok := g.pools[name] + if !ok { + return nil, fmt.Errorf("pool '%s' not found", name) + } + return srv, nil +} + +func (g *Gatling) ensureConfig(c *config.Global) error { + g.mu.Lock() + defer g.mu.Unlock() + + g.c = c + + if err := g.ensureGeneral(c); err != nil { + return err + } + if err := g.ensureAdmin(c); err != nil { + return err + } + if err := g.ensurePools(c); err != nil { + return err + } + + return nil +} + +// TODO: all other settings +func (g *Gatling) ensureGeneral(c *config.Global) error { + return nil +} + +// TODO: should configure the admin things, metrics, etc +func (g *Gatling) ensureAdmin(c *config.Global) error { + return nil +} + +// TODO: should connect to & load pools from config +func (g *Gatling) ensurePools(c *config.Global) error { + for name, p := range c.Pools { + if existing, ok := g.pools[name]; ok { + existing.EnsureConfig(&p) + } else { + g.pools[name] = pool.NewPool(&p) + } + } + return nil +} + +func (g *Gatling) ListenAndServe(ctx context.Context) error { + ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", g.c.General.Host, g.c.General.Port)) + if err != nil { + return err + } + for { + var c net.Conn + c, err = ln.Accept() + if err != nil { + return err + } + go func() { + err := g.handleConnection(ctx, c) + if err != nil { + log.Println("disconnected:", err) + } + }() + } +} + +// TODO: TLS +func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { + cl := client.NewClient(g, g.c, c, false) + err := cl.Accept(ctx) + if err != nil { + log.Println(err.Error()) + switch e := err.(type) { + case *pg_error.Error: + return cl.Send(e.Packet()) + default: + pgErr := &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.InternalError, + Message: e.Error(), + } + return cl.Send(pgErr.Packet()) + } + } + return nil +} + +var _ gat.Gat = (*Gatling)(nil) diff --git a/lib/gat/messages.go b/lib/gat/gatling/messages/messages.go similarity index 99% rename from lib/gat/messages.go rename to lib/gat/gatling/messages/messages.go index 86544c8c..836c72b8 100644 --- a/lib/gat/messages.go +++ b/lib/gat/gatling/messages/messages.go @@ -1,4 +1,4 @@ -package gat +package messages import ( "crypto/md5" diff --git a/lib/gat/messages_test.go b/lib/gat/gatling/messages/messages_test.go similarity index 79% rename from lib/gat/messages_test.go rename to lib/gat/gatling/messages/messages_test.go index 4fab8064..297aa034 100644 --- a/lib/gat/messages_test.go +++ b/lib/gat/gatling/messages/messages_test.go @@ -1,3 +1,3 @@ -package gat +package messages // TODO: once we decide on what messages, write the relevant test diff --git a/lib/gat/gatling/pool/pool.go b/lib/gat/gatling/pool/pool.go new file mode 100644 index 00000000..4325385c --- /dev/null +++ b/lib/gat/gatling/pool/pool.go @@ -0,0 +1,544 @@ +package pool + +import ( + "fmt" + "gfx.cafe/gfx/pggat/lib/config" + "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/gatling/conn_pool" +) + +type Pool struct { + c *config.Pool + users map[string]config.User + connPools map[string]*conn_pool.ConnectionPool +} + +func NewPool(conf *config.Pool) *Pool { + pool := &Pool{ + connPools: make(map[string]*conn_pool.ConnectionPool), + } + pool.EnsureConfig(conf) + return pool +} + +func (p *Pool) EnsureConfig(conf *config.Pool) { + p.c = conf + p.users = make(map[string]config.User) + for _, user := range conf.Users { + p.users[user.Name] = user + } + // ensure conn pools + for 
name, user := range p.users { + if existing, ok := p.connPools[name]; ok { + existing.EnsureConfig(conf) + } else { + u := user + p.connPools[name] = conn_pool.NewConnectionPool(p, conf, &u) + } + } +} + +func (p *Pool) GetUser(name string) (*config.User, error) { + user, ok := p.users[name] + if !ok { + return nil, fmt.Errorf("user '%s' not found", name) + } + return &user, nil +} + +func (p *Pool) WithUser(name string) (gat.ConnectionPool, error) { + pool, ok := p.connPools[name] + if !ok { + return nil, fmt.Errorf("no pool for '%s'", name) + } + return pool, nil +} + +var _ gat.Pool = (*Pool)(nil) + +//TODO: implement server pool +//#[async_trait] +//impl ManageConnection for ServerPool { +// type Connection = Server; +// type Err = Err; +// +// /// Attempts to create a new connection. +// async fn connect(&self) -> Result<Self::Connection, Self::Err> { +// info!( +// "Creating a new connection to {:?} using user {:?}", +// self.address.name(), +// self.user.username +// ); +// +// // Put a temporary process_id into the stats +// // for server login. +// let process_id = rand::random::<i32>(); +// self.stats.server_login(process_id, self.address.id); +// +// // Connect to the PostgreSQL server. +// match Server::startup( +// &self.address, +// &self.user, +// &self.database, +// self.client_server_map.clone(), +// self.stats.clone(), +// ) +// .await +// { +// Ok(conn) => { +// // Remove the temporary process_id from the stats. +// self.stats.server_disconnecting(process_id, self.address.id); +// Ok(conn) +// } +// Err(err) => { +// // Remove the temporary process_id from the stats. +// self.stats.server_disconnecting(process_id, self.address.id); +// Err(err) +// } +// } +// } +// +// /// Determines if the connection is still connected to the database. +// async fn is_valid(&self, _conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Err> { +// Ok(()) +// } +// +// /// Synchronously determine if the connection is no longer usable, if possible. +// fn has_broken(&self, conn: &mut Self::Connection) -> bool { +// conn.is_bad() +// } +//} +// +///// Get the connection pool +//pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> { +// match get_all_pools().get(&(db, user)) { +// Some(pool) => Some(pool.clone()), +// None => None, +// } +//} +// +///// How many total servers we have in the config. +//pub fn get_number_of_addresses() -> usize { +// get_all_pools() +// .iter() +// .map(|(_, pool)| pool.databases()) +// .sum() +//} +// +///// Get a pointer to all configured pools. +//pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> { +// return (*(*POOLS.load())).clone(); +//} + +//TODO: implement this +// /// Construct the connection pool from the configuration. +// func (c *ConnectionPool) from_config(client_server_map: ClientServerMap) Result<(), Err> { +// let config = get_config() +// +// new_pools = HashMap::new() +// address_id = 0 +// +// for (pool_name, pool_config) in &config.pools { +// // There is one pool per database/user pair. +// for (_, user) in &pool_config.users { +// shards = Vec::new() +// addresses = Vec::new() +// banlist = Vec::new() +// shard_ids = pool_config +// .shards +// .clone() +// .into_keys() +// .map(|x| x.to_string()) +// .collect::<Vec<string>>() +// +// // Sort by shard number to ensure consistency. 
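+// // A Go rendering of the numeric sort below, as a sketch (shardIds stands
+// // in for the collected []string of shard keys; standard library only):
+// //
+// //   sort.Slice(shardIds, func(i, j int) bool {
+// //       a, _ := strconv.Atoi(shardIds[i])
+// //       b, _ := strconv.Atoi(shardIds[j])
+// //       return a < b
+// //   })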
+// shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap()) +// +// for shard_idx in &shard_ids { +// let shard = &pool_config.shards[shard_idx] +// pools = Vec::new() +// servers = Vec::new() +// address_index = 0 +// replica_number = 0 +// +// for server in shard.servers.iter() { +// let role = match server.2.as_ref() { +// "primary" => Role::Primary, +// "replica" => Role::Replica, +// _ => { +// error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2) +// Role::Replica +// } +// } +// +// let address = Address { +// id: address_id, +// database: shard.database.clone(), +// host: server.0.clone(), +// port: server.1 as u16, +// role: role, +// address_index, +// replica_number, +// shard: shard_idx.parse::<usize>().unwrap(), +// username: user.username.clone(), +// pool_name: pool_name.clone(), +// } +// +// address_id += 1 +// address_index += 1 +// +// if role == Role::Replica { +// replica_number += 1 +// } +// +// let manager = ServerPool::new( +// address.clone(), +// user.clone(), +// &shard.database, +// client_server_map.clone(), +// get_reporter(), +// ) +// +// let pool = Pool::builder() +// .max_size(user.pool_size) +// .connection_timeout(std::time::Duration::from_millis( +// config.general.connect_timeout, +// )) +// .test_on_check_out(false) +// .build(manager) +// .await +// .unwrap() +// +// pools.push(pool) +// servers.push(address) +// } +// +// shards.push(pools) +// addresses.push(servers) +// banlist.push(HashMap::new()) +// } +// +// assert_eq!(shards.len(), addresses.len()) +// +// pool = ConnectionPool { +// databases: shards, +// addresses: addresses, +// banlist: Arc::new(RwLock::new(banlist)), +// stats: get_reporter(), +// server_info: BytesMut::new(), +// settings: PoolSettings { +// pool_mode: match pool_config.pool_mode.as_str() { +// "transaction" => PoolMode::Transaction, +// "session" => PoolMode::Session, +// _ => unreachable!(), +// }, +// // shards: pool_config.shards.clone(), +// shards: shard_ids.len(), +// user: user.clone(), +// default_role: match pool_config.default_role.as_str() { +// "any" => None, +// "replica" => Some(Role::Replica), +// "primary" => Some(Role::Primary), +// _ => unreachable!(), +// }, +// query_parser_enabled: pool_config.query_parser_enabled.clone(), +// primary_reads_enabled: pool_config.primary_reads_enabled, +// sharding_function: match pool_config.sharding_function.as_str() { +// "pg_bigint_hash" => ShardingFunction::PgBigintHash, +// "sha1" => ShardingFunction::Sha1, +// _ => unreachable!(), +// }, +// }, +// } +// +// // Connect to the servers to make sure pool configuration is valid +// // before setting it globally. +// match pool.validate().await { +// Ok(_) => (), +// Err(err) => { +// error!("Could not validate connection pool: {:?}", err) +// return Err(err) +// } +// } +// +// // There is one pool per database/user pair. +// new_pools.insert((pool_name.clone(), user.username.clone()), pool) +// } +// } +// +// POOLS.store(Arc::new(new_pools.clone())) +// +// Ok(()) +// } +// +// /// Connect to all shards and grab server information. +// /// Return server information we will pass to the clients +// /// when they connect. +// /// This also warms up the pool for clients that connect when +// /// the pooler starts up. 
+// async fn validate(&mut self) Result<(), Err> { +// server_infos = Vec::new() +// for shard in 0..self.shards() { +// for server in 0..self.servers(shard) { +// let connection = match self.databases[shard][server].get().await { +// Ok(conn) => conn, +// Err(err) => { +// error!("Shard {} down or misconfigured: {:?}", shard, err) +// continue +// } +// } +// +// let proxy = connection +// let server = &*proxy +// let server_info = server.server_info() +// +// if server_infos.len() > 0 { +// // Compare against the last server checked. +// if server_info != server_infos[server_infos.len() - 1] { +// warn!( +// "{:?} has different server configuration than the last server", +// proxy.address() +// ) +// } +// } +// +// server_infos.push(server_info) +// } +// } +// +// // TODO: compare server information to make sure +// // all shards are running identical configurations. +// if server_infos.len() == 0 { +// return Err(Err::AllServersDown) +// } +// +// // We're assuming all servers are identical. +// // TODO: not true. +// self.server_info = server_infos[0].clone() +// +// Ok(()) +// } +// +// /// Get a connection from the pool. +// func (c *ConnectionPool) get( +// &self, +// shard: usize, // shard number +// role: Option<Role>, // primary or replica +// process_id: i32, // client id +// ) Result<(PooledConnection<'_, ServerPool>, Address), Err> { +// let now = Instant::now() +// candidates: Vec<&Address> = self.addresses[shard] +// .iter() +// .filter(|address| address.role == role) +// .collect() +// +// // Random load balancing +// candidates.shuffle(&mut thread_rng()) +// +// let healthcheck_timeout = get_config().general.healthcheck_timeout +// let healthcheck_delay = get_config().general.healthcheck_delay as u128 +// +// while !candidates.is_empty() { +// // Get the next candidate +// let address = match candidates.pop() { +// Some(address) => address, +// None => break, +// } +// +// if self.is_banned(&address, role) { +// debug!("Address {:?} is banned", address) +// continue +// } +// +// // Indicate we're waiting on a server connection from a pool. +// self.stats.client_waiting(process_id, address.id) +// +// // Check if we can connect +// conn = match self.databases[address.shard][address.address_index] +// .get() +// .await +// { +// Ok(conn) => conn, +// Err(err) => { +// error!("Banning instance {:?}, error: {:?}", address, err) +// self.ban(&address, process_id) +// self.stats +// .checkout_time(now.elapsed().as_micros(), process_id, address.id) +// continue +// } +// } +// +// // // Check if this server is alive with a health check. +// let server = &mut *conn +// +// // Will return error if timestamp is greater than current system time, which it should never be set to +// let require_healthcheck = +// server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay +// +// // Do not issue a health check unless it's been a little while +// // since we last checked the server is ok. +// // Health checks are pretty expensive. 
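+// // In Go the same staleness gate is a plain time comparison; a minimal
+// // sketch, assuming a lastActivity timestamp on the server connection and
+// // a configured healthcheck delay:
+// //
+// //   needCheck := time.Since(srv.lastActivity) > healthcheckDelay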
+// if !require_healthcheck {
+// self.stats
+// .checkout_time(now.elapsed().as_micros(), process_id, address.id)
+// self.stats.server_active(conn.process_id(), address.id)
+// return Ok((conn, address.clone()))
+// }
+//
+// debug!("Running health check on server {:?}", address)
+//
+// self.stats.server_tested(server.process_id(), address.id)
+//
+// match tokio::time::timeout(
+// tokio::time::Duration::from_millis(healthcheck_timeout),
+// server.query(""), // Cheap query (query parser not used in PG)
+// )
+// .await
+// {
+// // Check if health check succeeded.
+// Ok(res) => match res {
+// Ok(_) => {
+// self.stats
+// .checkout_time(now.elapsed().as_micros(), process_id, address.id)
+// self.stats.server_active(conn.process_id(), address.id)
+// return Ok((conn, address.clone()))
+// }
+//
+// // Health check failed.
+// Err(err) => {
+// error!(
+// "Banning instance {:?} because of failed health check, {:?}",
+// address, err
+// )
+//
+// // Don't leave a bad connection in the pool.
+// server.mark_bad()
+//
+// self.ban(&address, process_id)
+// continue
+// }
+// },
+//
+// // Health check timed out.
+// Err(err) => {
+// error!(
+// "Banning instance {:?} because of health check timeout, {:?}",
+// address, err
+// )
+// // Don't leave a bad connection in the pool.
+// server.mark_bad()
+//
+// self.ban(&address, process_id)
+// continue
+// }
+// }
+// }
+//
+// Err(Err::AllServersDown)
+// }
+//
+// /// Ban an address (i.e. replica). It no longer will serve
+// /// traffic for any new transactions. Existing transactions on that replica
+// /// will finish successfully or error out to the clients.
+// func (c *ConnectionPool) ban(&self, address: &Address, process_id: i32) {
+// self.stats.client_disconnecting(process_id, address.id)
+//
+// error!("Banning {:?}", address)
+//
+// let now = chrono::offset::Utc::now().naive_utc()
+// guard = self.banlist.write()
+// guard[address.shard].insert(address.clone(), now)
+// }
+//
+// /// Clear the replica to receive traffic again. Takes effect immediately
+// /// for all new transactions.
+// func (c *ConnectionPool) _unban(&self, address: &Address) {
+// guard = self.banlist.write()
+// guard[address.shard].remove(address)
+// }
+//
+// /// Check if a replica can serve traffic. If all replicas are banned,
+// /// we unban all of them. Better to try than not to.
+// func (c *ConnectionPool) is_banned(&self, address: &Address, role: Option<Role>) bool {
+// let replicas_available = match role {
+// Some(Role::Replica) => self.addresses[address.shard]
+// .iter()
+// .filter(|addr| addr.role == Role::Replica)
+// .count(),
+// None => self.addresses[address.shard].len(),
+// Some(Role::Primary) => return false, // Primary cannot be banned.
+// }
+//
+// debug!("Available targets for {:?}: {}", role, replicas_available)
+//
+// let guard = self.banlist.read()
+//
+// // Everything is banned = nothing is banned.
+// if guard[address.shard].len() == replicas_available {
+// drop(guard)
+// guard = self.banlist.write()
+// guard[address.shard].clear()
+// drop(guard)
+// warn!("Unbanning all replicas.")
+// return false
+// }
+//
+// // I expect this to miss 99.9999% of the time.
+// match guard[address.shard].get(address) {
+// Some(timestamp) => {
+// let now = chrono::offset::Utc::now().naive_utc()
+// let config = get_config()
+//
+// // Ban expired.
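+// // A Go version of this expiry test could be simply (sketch, names
+// // assumed): if time.Since(bannedAt) > banTime { /* unban */ }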
+// if now.timestamp() - timestamp.timestamp() > config.general.ban_time { +// drop(guard) +// warn!("Unbanning {:?}", address) +// guard = self.banlist.write() +// guard[address.shard].remove(address) +// false +// } else { +// debug!("{:?} is banned", address) +// true +// } +// } +// +// None => { +// debug!("{:?} is ok", address) +// false +// } +// } +// } +// +// /// Get the number of configured shards. +// func (c *ConnectionPool) shards(&self) usize { +// self.databases.len() +// } +// +// /// Get the number of servers (primary and replicas) +// /// configured for a shard. +// func (c *ConnectionPool) servers(&self, shard: usize) usize { +// self.addresses[shard].len() +// } +// +// /// Get the total number of servers (databases) we are connected to. +// func (c *ConnectionPool) databases(&self) usize { +// databases = 0 +// for shard in 0..self.shards() { +// databases += self.servers(shard) +// } +// databases +// } +// +// /// Get pool state for a particular shard server as reported by bb8. +// func (c *ConnectionPool) pool_state(&self, shard: usize, server: usize) bb8::State { +// self.databases[shard][server].state() +// } +// +// /// Get the address information for a shard server. +// func (c *ConnectionPool) address(&self, shard: usize, server: usize) &Address { +// &self.addresses[shard][server] +// } +// +// func (c *ConnectionPool) server_info(&self) BytesMut { +// self.server_info.clone() +// } diff --git a/lib/gat/pool_test.go b/lib/gat/gatling/pool/pool_test.go similarity index 77% rename from lib/gat/pool_test.go rename to lib/gat/gatling/pool/pool_test.go index b9a0eb6a..9fbebf40 100644 --- a/lib/gat/pool_test.go +++ b/lib/gat/gatling/pool/pool_test.go @@ -1,3 +1,3 @@ -package gat +package pool // TODO: no tests, we need to write our own diff --git a/lib/gat/query_router.go b/lib/gat/gatling/query_router/query_router.go similarity index 97% rename from lib/gat/query_router.go rename to lib/gat/gatling/query_router/query_router.go index 41eb86c4..bfaaea2a 100644 --- a/lib/gat/query_router.go +++ b/lib/gat/gatling/query_router/query_router.go @@ -1,4 +1,4 @@ -package gat +package query_router import ( "fmt" @@ -50,14 +50,17 @@ type CommandShowPrimaryReads struct{} type QueryRouter struct { active_shard int primary_reads_enabled bool - pool_settings PoolSettings + //pool_settings pool.PoolSettings } +/* TODO // / Pool settings can change because of a config reload. -func (r *QueryRouter) UpdatePoolSettings(pool_settings PoolSettings) { +func (r *QueryRouter) UpdatePoolSettings(pool_settings pool.PoolSettings) { r.pool_settings = pool_settings } +*/ + // / Try to parse a command and execute it. 
// TODO: needs to just provide the execution function so gatling can then plug in the client, server, etc
 func (r *QueryRouter) try_execute_command(pkt *protocol.Query) (Command, string) {
diff --git a/lib/gat/query_router_test.go b/lib/gat/gatling/query_router/query_router_test.go
similarity index 99%
rename from lib/gat/query_router_test.go
rename to lib/gat/gatling/query_router/query_router_test.go
index a84edaa8..8d294888 100644
--- a/lib/gat/query_router_test.go
+++ b/lib/gat/gatling/query_router/query_router_test.go
@@ -1,4 +1,4 @@
-package gat
+package query_router

 import (
 "testing"
diff --git a/lib/gat/server.go b/lib/gat/gatling/server/server.go
similarity index 84%
rename from lib/gat/server.go
rename to lib/gat/gatling/server/server.go
index 2ac8c13e..874da5ee 100644
--- a/lib/gat/server.go
+++ b/lib/gat/gatling/server/server.go
@@ -1,10 +1,11 @@
-package gat
+package server

 import (
 "bufio"
 "bytes"
 "encoding/binary"
 "fmt"
+ "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error"
 "io"
 "net"
 "time"
@@ -51,7 +52,7 @@ type Server struct {

 var ENDIAN = binary.BigEndian

-func DialServer(ctx context.Context, addr string, user *config.User, db string, stats any) (*Server, error) {
+func Dial(ctx context.Context, addr string, user *config.User, db string, stats any) (*Server, error) {
 s := &Server{}
 var err error
 s.conn, err = net.Dial("tcp", addr)
@@ -61,7 +62,6 @@ func DialServer(ctx context.Context, addr string, user *config.User, db string,
 s.remote = s.conn.RemoteAddr()
 s.r = bufio.NewReader(s.conn)
 s.wr = s.conn
- s.server_info = []*protocol.ParameterStatus{}

 s.user = *user
 s.db = db
@@ -73,6 +73,10 @@ func DialServer(ctx context.Context, addr string, user *config.User, db string,
 return s, s.connect(ctx)
 }

+func (s *Server) GetServerInfo() []*protocol.ParameterStatus {
+ return s.server_info
+}
+
 func (s *Server) startup(ctx context.Context) error {
 s.log.Debug().Msg("sending startup")
 start := new(protocol.StartupMessage)
@@ -94,6 +98,7 @@ func (s *Server) startup(ctx context.Context) error {
 }
 return nil
 }
+
 func (s *Server) connect(ctx context.Context) error {
 err := s.startup(ctx)
 if err != nil {
@@ -171,7 +176,7 @@ func (s *Server) connect(ctx context.Context) error {
 s.log.Debug().Str("method", "scram256").Msg("sasl success")
 }
 case *protocol.ErrorResponse:
- pgErr := new(PostgresError)
+ pgErr := new(pg_error.Error)
 pgErr.Read(p)
 return pgErr
 case *protocol.ParameterStatus:
@@ -189,6 +194,35 @@ func (s *Server) connect(ctx context.Context) error {
 }
 }

+func (s *Server) Query(query string, rep chan<- protocol.Packet) error {
+ // send to server
+ q := new(protocol.Query)
+ q.Fields.Query = query
+ _, err := q.Write(s.wr)
+ if err != nil {
+ return err
+ }
+
+ // read responses
+ for {
+ var rsp protocol.Packet
+ rsp, err = protocol.ReadBackend(s.r)
+ if err != nil {
+ return err
+ }
+ switch r := rsp.(type) {
+ case *protocol.ReadyForQuery:
+ if r.Fields.Status == 'I' {
+ rep <- rsp
+ return nil
+ }
+ case *protocol.CopyInResponse, *protocol.CopyOutResponse, *protocol.CopyBothResponse:
+ return fmt.Errorf("unsupported")
+ }
+ rep <- rsp
+ }
+}
+
 func (s *Server) Close(ctx context.Context) error {
 <-ctx.Done()
 return nil
diff --git a/lib/gat/server_test.go b/lib/gat/gatling/server/server_test.go
similarity index 65%
rename from lib/gat/server_test.go
rename to lib/gat/gatling/server/server_test.go
index 2e3a0a42..ad66505c 100644
--- a/lib/gat/server_test.go
+++ b/lib/gat/gatling/server/server_test.go
@@ -1,7 +1,8 @@
-package gat
+package server

 import (
 "context"
+
"gfx.cafe/gfx/pggat/lib/gat/gatling/client" "git.tuxpa.in/a/zlog/log" "testing" @@ -18,8 +19,8 @@ var test_user = config.User{ } func TestServerDial(t *testing.T) { - csm := make(map[ClientInfo]ClientInfo) - srv, err := DialServer(context.TODO(), test_address, &test_user, "postgres", csm, nil) + csm := make(map[client.ClientInfo]client.ClientInfo) + srv, err := Dial(context.TODO(), test_address, &test_user, "postgres", csm, nil) if err != nil { t.Error(err) } diff --git a/lib/gat/sharding.go b/lib/gat/gatling/sharding/sharding.go similarity index 99% rename from lib/gat/sharding.go rename to lib/gat/gatling/sharding/sharding.go index 4deb7267..8a471a2e 100644 --- a/lib/gat/sharding.go +++ b/lib/gat/gatling/sharding/sharding.go @@ -1,4 +1,4 @@ -package gat +package sharding const PARTITION_HASH_SEED = 0x7A5B22367996DCFD diff --git a/lib/gat/sharding_test.go b/lib/gat/gatling/sharding/sharding_test.go similarity index 99% rename from lib/gat/sharding_test.go rename to lib/gat/gatling/sharding/sharding_test.go index 23bdd97b..356dd2cd 100644 --- a/lib/gat/sharding_test.go +++ b/lib/gat/gatling/sharding/sharding_test.go @@ -1,4 +1,4 @@ -package gat +package sharding //TODO: convert test diff --git a/lib/gat/stats.go b/lib/gat/gatling/stats/stats.go similarity index 99% rename from lib/gat/stats.go rename to lib/gat/gatling/stats/stats.go index 8daa678f..edeae5f5 100644 --- a/lib/gat/stats.go +++ b/lib/gat/gatling/stats/stats.go @@ -1,4 +1,4 @@ -package gat +package stats //TODO: metrics // let's do this last. we can use the go package for prometheus, its way better than anything we could do diff --git a/lib/gat/pool.go b/lib/gat/pool.go index 6fc53268..07a71492 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -1,566 +1,8 @@ package gat -import ( - "gfx.cafe/gfx/pggat/lib/config" -) +import "gfx.cafe/gfx/pggat/lib/config" -type PoolSettings struct { - /// Transaction or Session. - pool_mode config.PoolMode - - // Number of shards. - shards int - - // Connecting user. - user config.User - - // Default server role to connect to. - default_role config.ServerRole - - // Enable/disable query parser. - query_parser_enabled bool - - // Read from the primary as well or not. - primary_reads_enabled bool - - // Sharding function. - sharding_function ShardFunc -} - -func DefaultPool() PoolSettings { - return PoolSettings{ - pool_mode: config.POOLMODE_TXN, - shards: 1, - user: config.User{ - Name: "postgres", - Password: "test", - }, - default_role: config.SERVERROLE_NONE, - query_parser_enabled: false, - primary_reads_enabled: true, - //TODO: select default sharding function - sharding_function: nil, - } -} - -type ServerPool struct { - address string - user config.User - database string - client_server_map map[ClientKey]ClientInfo - stats any //TODO: stats +type Pool interface { + GetUser(name string) (*config.User, error) + WithUser(name string) (ConnectionPool, error) } - -//TODO: implement server pool -//#[async_trait] -//impl ManageConnection for ServerPool { -// type Connection = Server; -// type Error = Error; -// -// /// Attempts to create a new connection. -// async fn connect(&self) -> Result<Self::Connection, Self::Error> { -// info!( -// "Creating a new connection to {:?} using user {:?}", -// self.address.name(), -// self.user.username -// ); -// -// // Put a temporary process_id into the stats -// // for server login. -// let process_id = rand::random::<i32>(); -// self.stats.server_login(process_id, self.address.id); -// -// // Connect to the PostgreSQL server. 
-// match Server::startup( -// &self.address, -// &self.user, -// &self.database, -// self.client_server_map.clone(), -// self.stats.clone(), -// ) -// .await -// { -// Ok(conn) => { -// // Remove the temporary process_id from the stats. -// self.stats.server_disconnecting(process_id, self.address.id); -// Ok(conn) -// } -// Err(err) => { -// // Remove the temporary process_id from the stats. -// self.stats.server_disconnecting(process_id, self.address.id); -// Err(err) -// } -// } -// } -// -// /// Determines if the connection is still connected to the database. -// async fn is_valid(&self, _conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> { -// Ok(()) -// } -// -// /// Synchronously determine if the connection is no longer usable, if possible. -// fn has_broken(&self, conn: &mut Self::Connection) -> bool { -// conn.is_bad() -// } -//} -// -///// Get the connection pool -//pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> { -// match get_all_pools().get(&(db, user)) { -// Some(pool) => Some(pool.clone()), -// None => None, -// } -//} -// -///// How many total servers we have in the config. -//pub fn get_number_of_addresses() -> usize { -// get_all_pools() -// .iter() -// .map(|(_, pool)| pool.databases()) -// .sum() -//} -// -///// Get a pointer to all configured pools. -//pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> { -// return (*(*POOLS.load())).clone(); -//} - -type ConnectionPool struct { - /// The pools handled internally by bb8. - // TODO: https://docs.rs/bb8-bolt/latest/src/bb8_bolt/lib.rs.html - databases [][]ServerPool - - /// The addresses (host, port, role) to handle - /// failover and load balancing deterministically. - addresses []config.Server - - /// List of banned addresses (see above) - /// that should not be queried. - banlist []string - - /// The statistics aggregator runs in a separate task - /// and receives stats from clients, servers, and the pool. - stats any // TODO: stats - - /// The server information (K messages) have to be passed to the - /// clients on startup. We pre-connect to all shards and replicas - /// on pool creation and save the K messages here. - /// TODO: consider storing this in a better format - server_info []byte - - /// Pool configuration. - settings PoolSettings -} - -//TODO: implement this -// /// Construct the connection pool from the configuration. -// func (c *ConnectionPool) from_config(client_server_map: ClientServerMap) Result<(), Error> { -// let config = get_config() -// -// new_pools = HashMap::new() -// address_id = 0 -// -// for (pool_name, pool_config) in &config.pools { -// // There is one pool per database/user pair. -// for (_, user) in &pool_config.users { -// shards = Vec::new() -// addresses = Vec::new() -// banlist = Vec::new() -// shard_ids = pool_config -// .shards -// .clone() -// .into_keys() -// .map(|x| x.to_string()) -// .collect::<Vec<string>>() -// -// // Sort by shard number to ensure consistency. -// shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap()) -// -// for shard_idx in &shard_ids { -// let shard = &pool_config.shards[shard_idx] -// pools = Vec::new() -// servers = Vec::new() -// address_index = 0 -// replica_number = 0 -// -// for server in shard.servers.iter() { -// let role = match server.2.as_ref() { -// "primary" => Role::Primary, -// "replica" => Role::Replica, -// _ => { -// error!("Config error: server role can be 'primary' or 'replica', have: '{}'. 
Defaulting to 'replica'.", server.2) -// Role::Replica -// } -// } -// -// let address = Address { -// id: address_id, -// database: shard.database.clone(), -// host: server.0.clone(), -// port: server.1 as u16, -// role: role, -// address_index, -// replica_number, -// shard: shard_idx.parse::<usize>().unwrap(), -// username: user.username.clone(), -// pool_name: pool_name.clone(), -// } -// -// address_id += 1 -// address_index += 1 -// -// if role == Role::Replica { -// replica_number += 1 -// } -// -// let manager = ServerPool::new( -// address.clone(), -// user.clone(), -// &shard.database, -// client_server_map.clone(), -// get_reporter(), -// ) -// -// let pool = Pool::builder() -// .max_size(user.pool_size) -// .connection_timeout(std::time::Duration::from_millis( -// config.general.connect_timeout, -// )) -// .test_on_check_out(false) -// .build(manager) -// .await -// .unwrap() -// -// pools.push(pool) -// servers.push(address) -// } -// -// shards.push(pools) -// addresses.push(servers) -// banlist.push(HashMap::new()) -// } -// -// assert_eq!(shards.len(), addresses.len()) -// -// pool = ConnectionPool { -// databases: shards, -// addresses: addresses, -// banlist: Arc::new(RwLock::new(banlist)), -// stats: get_reporter(), -// server_info: BytesMut::new(), -// settings: PoolSettings { -// pool_mode: match pool_config.pool_mode.as_str() { -// "transaction" => PoolMode::Transaction, -// "session" => PoolMode::Session, -// _ => unreachable!(), -// }, -// // shards: pool_config.shards.clone(), -// shards: shard_ids.len(), -// user: user.clone(), -// default_role: match pool_config.default_role.as_str() { -// "any" => None, -// "replica" => Some(Role::Replica), -// "primary" => Some(Role::Primary), -// _ => unreachable!(), -// }, -// query_parser_enabled: pool_config.query_parser_enabled.clone(), -// primary_reads_enabled: pool_config.primary_reads_enabled, -// sharding_function: match pool_config.sharding_function.as_str() { -// "pg_bigint_hash" => ShardingFunction::PgBigintHash, -// "sha1" => ShardingFunction::Sha1, -// _ => unreachable!(), -// }, -// }, -// } -// -// // Connect to the servers to make sure pool configuration is valid -// // before setting it globally. -// match pool.validate().await { -// Ok(_) => (), -// Err(err) => { -// error!("Could not validate connection pool: {:?}", err) -// return Err(err) -// } -// } -// -// // There is one pool per database/user pair. -// new_pools.insert((pool_name.clone(), user.username.clone()), pool) -// } -// } -// -// POOLS.store(Arc::new(new_pools.clone())) -// -// Ok(()) -// } -// -// /// Connect to all shards and grab server information. -// /// Return server information we will pass to the clients -// /// when they connect. -// /// This also warms up the pool for clients that connect when -// /// the pooler starts up. -// async fn validate(&mut self) Result<(), Error> { -// server_infos = Vec::new() -// for shard in 0..self.shards() { -// for server in 0..self.servers(shard) { -// let connection = match self.databases[shard][server].get().await { -// Ok(conn) => conn, -// Err(err) => { -// error!("Shard {} down or misconfigured: {:?}", shard, err) -// continue -// } -// } -// -// let proxy = connection -// let server = &*proxy -// let server_info = server.server_info() -// -// if server_infos.len() > 0 { -// // Compare against the last server checked. 
-// if server_info != server_infos[server_infos.len() - 1] { -// warn!( -// "{:?} has different server configuration than the last server", -// proxy.address() -// ) -// } -// } -// -// server_infos.push(server_info) -// } -// } -// -// // TODO: compare server information to make sure -// // all shards are running identical configurations. -// if server_infos.len() == 0 { -// return Err(Error::AllServersDown) -// } -// -// // We're assuming all servers are identical. -// // TODO: not true. -// self.server_info = server_infos[0].clone() -// -// Ok(()) -// } -// -// /// Get a connection from the pool. -// func (c *ConnectionPool) get( -// &self, -// shard: usize, // shard number -// role: Option<Role>, // primary or replica -// process_id: i32, // client id -// ) Result<(PooledConnection<'_, ServerPool>, Address), Error> { -// let now = Instant::now() -// candidates: Vec<&Address> = self.addresses[shard] -// .iter() -// .filter(|address| address.role == role) -// .collect() -// -// // Random load balancing -// candidates.shuffle(&mut thread_rng()) -// -// let healthcheck_timeout = get_config().general.healthcheck_timeout -// let healthcheck_delay = get_config().general.healthcheck_delay as u128 -// -// while !candidates.is_empty() { -// // Get the next candidate -// let address = match candidates.pop() { -// Some(address) => address, -// None => break, -// } -// -// if self.is_banned(&address, role) { -// debug!("Address {:?} is banned", address) -// continue -// } -// -// // Indicate we're waiting on a server connection from a pool. -// self.stats.client_waiting(process_id, address.id) -// -// // Check if we can connect -// conn = match self.databases[address.shard][address.address_index] -// .get() -// .await -// { -// Ok(conn) => conn, -// Err(err) => { -// error!("Banning instance {:?}, error: {:?}", address, err) -// self.ban(&address, process_id) -// self.stats -// .checkout_time(now.elapsed().as_micros(), process_id, address.id) -// continue -// } -// } -// -// // // Check if this server is alive with a health check. -// let server = &mut *conn -// -// // Will return error if timestamp is greater than current system time, which it should never be set to -// let require_healthcheck = -// server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay -// -// // Do not issue a health check unless it's been a little while -// // since we last checked the server is ok. -// // Health checks are pretty expensive. -// if !require_healthcheck { -// self.stats -// .checkout_time(now.elapsed().as_micros(), process_id, address.id) -// self.stats.server_active(conn.process_id(), address.id) -// return Ok((conn, address.clone())) -// } -// -// debug!("Running health check on server {:?}", address) -// -// self.stats.server_tested(server.process_id(), address.id) -// -// match tokio::time::timeout( -// tokio::time::Duration::from_millis(healthcheck_timeout), -// server.query(""), // Cheap query (query parser not used in PG) -// ) -// .await -// { -// // Check if health check succeeded. -// Ok(res) => match res { -// Ok(_) => { -// self.stats -// .checkout_time(now.elapsed().as_micros(), process_id, address.id) -// self.stats.server_active(conn.process_id(), address.id) -// return Ok((conn, address.clone())) -// } -// -// // Health check failed. -// Err(err) => { -// error!( -// "Banning instance {:?} because of failed health check, {:?}", -// address, err -// ) -// -// // Don't leave a bad connection in the pool. 
-// server.mark_bad() -// -// self.ban(&address, process_id) -// continue -// } -// }, -// -// // Health check timed out. -// Err(err) => { -// error!( -// "Banning instance {:?} because of health check timeout, {:?}", -// address, err -// ) -// // Don't leave a bad connection in the pool. -// server.mark_bad() -// -// self.ban(&address, process_id) -// continue -// } -// } -// } -// -// Err(Error::AllServersDown) -// } -// -// /// Ban an address (i.e. replica). It no longer will serve -// /// traffic for any new transactions. Existing transactions on that replica -// /// will finish successfully or error out to the clients. -// func (c *ConnectionPool) ban(&self, address: &Address, process_id: i32) { -// self.stats.client_disconnecting(process_id, address.id) -// -// error!("Banning {:?}", address) -// -// let now = chrono::offset::Utc::now().naive_utc() -// guard = self.banlist.write() -// guard[address.shard].insert(address.clone(), now) -// } -// -// /// Clear the replica to receive traffic again. Takes effect immediately -// /// for all new transactions. -// func (c *ConnectionPool) _unban(&self, address: &Address) { -// guard = self.banlist.write() -// guard[address.shard].remove(address) -// } -// -// /// Check if a replica can serve traffic. If all replicas are banned, -// /// we unban all of them. Better to try then not to. -// func (c *ConnectionPool) is_banned(&self, address: &Address, role: Option<Role>) bool { -// let replicas_available = match role { -// Some(Role::Replica) => self.addresses[address.shard] -// .iter() -// .filter(|addr| addr.role == Role::Replica) -// .count(), -// None => self.addresses[address.shard].len(), -// Some(Role::Primary) => return false, // Primary cannot be banned. -// } -// -// debug!("Available targets for {:?}: {}", role, replicas_available) -// -// let guard = self.banlist.read() -// -// // Everything is banned = nothing is banned. -// if guard[address.shard].len() == replicas_available { -// drop(guard) -// guard = self.banlist.write() -// guard[address.shard].clear() -// drop(guard) -// warn!("Unbanning all replicas.") -// return false -// } -// -// // I expect this to miss 99.9999% of the time. -// match guard[address.shard].get(address) { -// Some(timestamp) => { -// let now = chrono::offset::Utc::now().naive_utc() -// let config = get_config() -// -// // Ban expired. -// if now.timestamp() - timestamp.timestamp() > config.general.ban_time { -// drop(guard) -// warn!("Unbanning {:?}", address) -// guard = self.banlist.write() -// guard[address.shard].remove(address) -// false -// } else { -// debug!("{:?} is banned", address) -// true -// } -// } -// -// None => { -// debug!("{:?} is ok", address) -// false -// } -// } -// } -// -// /// Get the number of configured shards. -// func (c *ConnectionPool) shards(&self) usize { -// self.databases.len() -// } -// -// /// Get the number of servers (primary and replicas) -// /// configured for a shard. -// func (c *ConnectionPool) servers(&self, shard: usize) usize { -// self.addresses[shard].len() -// } -// -// /// Get the total number of servers (databases) we are connected to. -// func (c *ConnectionPool) databases(&self) usize { -// databases = 0 -// for shard in 0..self.shards() { -// databases += self.servers(shard) -// } -// databases -// } -// -// /// Get pool state for a particular shard server as reported by bb8. 
-// func (c *ConnectionPool) pool_state(&self, shard: usize, server: usize) bb8::State { -// self.databases[shard][server].state() -// } -// -// /// Get the address information for a shard server. -// func (c *ConnectionPool) address(&self, shard: usize, server: usize) &Address { -// &self.addresses[shard][server] -// } -// -// func (c *ConnectionPool) server_info(&self) BytesMut { -// self.server_info.clone() -// } diff --git a/lib/gat/error.go b/lib/gat/protocol/pg_error/error.go similarity index 99% rename from lib/gat/error.go rename to lib/gat/protocol/pg_error/error.go index f8d50f2d..598d4a3d 100644 --- a/lib/gat/error.go +++ b/lib/gat/protocol/pg_error/error.go @@ -1,4 +1,4 @@ -package gat +package pg_error import ( "fmt" @@ -11,7 +11,7 @@ import ( type Severity string const ( - Error Severity = "ERROR" + Err Severity = "ERROR" Fatal = "FATAL" Panic = "PANIC" Warn = "WARNING" @@ -285,7 +285,7 @@ const ( IndexCorrupted = "XX002" ) -type PostgresError struct { +type Error struct { Severity Severity Code Code Message string @@ -305,7 +305,7 @@ type PostgresError struct { Routine string } -func (E *PostgresError) Read(pkt *protocol.ErrorResponse) { +func (E *Error) Read(pkt *protocol.ErrorResponse) { for _, field := range pkt.Fields.Responses { switch field.Code { case byte('S'): @@ -346,7 +346,7 @@ func (E *PostgresError) Read(pkt *protocol.ErrorResponse) { } } -func (E *PostgresError) Packet() *protocol.ErrorResponse { +func (E *Error) Packet() *protocol.ErrorResponse { var fields []protocol.FieldsErrorResponseResponses fields = append(fields, protocol.FieldsErrorResponseResponses{ Code: byte('S'), @@ -448,6 +448,6 @@ func (E *PostgresError) Packet() *protocol.ErrorResponse { return pkt } -func (E *PostgresError) Error() string { +func (E *Error) Error() string { return fmt.Sprintf("%s: %s", E.Severity, E.Message) } -- GitLab