diff --git a/lib/gat/admin.go b/lib/gat/admin.go new file mode 100644 index 0000000000000000000000000000000000000000..f851ee73b5d51da4f0476344a4b1ea239169b596 --- /dev/null +++ b/lib/gat/admin.go @@ -0,0 +1 @@ +package gat diff --git a/lib/gat/messages.go b/lib/gat/messages.go new file mode 100644 index 0000000000000000000000000000000000000000..3a8f8688ed2a9d19fcb517ba46d3d46664b12116 --- /dev/null +++ b/lib/gat/messages.go @@ -0,0 +1,490 @@ +package gat + +// TODO: decide which of these we need and don't need. +// implement the ones we need + +/// Tell the client that authentication handshake completed successfully. +//pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// let mut auth_ok = BytesMut::with_capacity(9); +// +// auth_ok.put_u8(b'R'); +// auth_ok.put_i32(8); +// auth_ok.put_i32(0); +// +// Ok(write_all(stream, auth_ok).await?) +//} +// +///// Generate md5 password challenge. +//pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// // let mut rng = rand::thread_rng(); +// let salt: [u8; 4] = [ +// rand::random(), +// rand::random(), +// rand::random(), +// rand::random(), +// ]; +// +// let mut res = BytesMut::new(); +// res.put_u8(b'R'); +// res.put_i32(12); +// res.put_i32(5); // MD5 +// res.put_slice(&salt[..]); +// +// write_all(stream, res).await?; +// Ok(salt) +//} +// +///// Give the client the process_id and secret we generated +///// used in query cancellation. +//pub async fn backend_key_data<S>( +// stream: &mut S, +// backend_id: i32, +// secret_key: i32, +//) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// let mut key_data = BytesMut::from(&b"K"[..]); +// key_data.put_i32(12); +// key_data.put_i32(backend_id); +// key_data.put_i32(secret_key); +// +// Ok(write_all(stream, key_data).await?) +//} +// +///// Construct a `Q`: Query message. +//pub fn simple_query(query: &str) -> BytesMut { +// let mut res = BytesMut::from(&b"Q"[..]); +// let query = format!("{}\0", query); +// +// res.put_i32(query.len() as i32 + 4); +// res.put_slice(&query.as_bytes()); +// +// res +//} +// +///// Tell the client we're ready for another query. +//pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// let mut bytes = BytesMut::with_capacity( +// mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(), +// ); +// +// bytes.put_u8(b'Z'); +// bytes.put_i32(5); +// bytes.put_u8(b'I'); // Idle +// +// Ok(write_all(stream, bytes).await?) +//} +// +///// Send the startup packet to the server. We're pretending we're a Pg client. +///// This tells the server which user we are and what database we want.
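As a rough Go counterpart for the port, the StartupMessage framing described above might look like the sketch below; the helper name startupPacket and the use of encoding/binary are assumptions of mine, not part of the existing code. The commented-out Rust reference that follows builds the same frame with BytesMut.

import "encoding/binary"

// startupPacket frames a StartupMessage: a 4-byte big-endian length that
// includes itself, the protocol version 196608 (3.0), then null-terminated
// "user"/"database" key-value pairs and a final terminating null byte.
func startupPacket(user, database string) []byte {
    var body []byte
    body = append(body, 0x00, 0x03, 0x00, 0x00) // protocol number 196608
    body = append(body, "user"...)
    body = append(body, 0)
    body = append(body, user...)
    body = append(body, 0)
    body = append(body, "database"...)
    body = append(body, 0)
    body = append(body, database...)
    body = append(body, 0, 0) // value terminator + end of parameter list

    packet := make([]byte, 4, 4+len(body))
    binary.BigEndian.PutUint32(packet, uint32(len(body)+4))
    return append(packet, body...)
}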
+//pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { +// let mut bytes = BytesMut::with_capacity(25); +// +// bytes.put_i32(196608); // Protocol number +// +// // User +// bytes.put(&b"user\0"[..]); +// bytes.put_slice(&user.as_bytes()); +// bytes.put_u8(0); +// +// // Database +// bytes.put(&b"database\0"[..]); +// bytes.put_slice(&database.as_bytes()); +// bytes.put_u8(0); +// bytes.put_u8(0); // Null terminator +// +// let len = bytes.len() as i32 + 4i32; +// +// let mut startup = BytesMut::with_capacity(len as usize); +// +// startup.put_i32(len); +// startup.put(bytes); +// +// match stream.write_all(&startup).await { +// Ok(_) => Ok(()), +// Err(_) => return Err(Error::SocketError), +// } +//} +// +///// Parse the params the server sends as a key/value format. +//pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> { +// let mut result = HashMap::new(); +// let mut buf = Vec::new(); +// let mut tmp = String::new(); +// +// while bytes.has_remaining() { +// let mut c = bytes.get_u8(); +// +// // Null-terminated C-strings. +// while c != 0 { +// tmp.push(c as char); +// c = bytes.get_u8(); +// } +// +// if tmp.len() > 0 { +// buf.push(tmp.clone()); +// tmp.clear(); +// } +// } +// +// // Expect pairs of name and value +// // and at least one pair to be present. +// if buf.len() % 2 != 0 || buf.len() < 2 { +// return Err(Error::ClientBadStartup); +// } +// +// let mut i = 0; +// while i < buf.len() { +// let name = buf[i].clone(); +// let value = buf[i + 1].clone(); +// let _ = result.insert(name, value); +// i += 2; +// } +// +// Ok(result) +//} +// +///// Parse StartupMessage parameters. +///// e.g. user, database, application_name, etc. +//pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error> { +// let result = parse_params(bytes)?; +// +// // Minimum required parameters +// // I want to have the user at the very minimum, according to the protocol spec. +// if !result.contains_key("user") { +// return Err(Error::ClientBadStartup); +// } +// +// Ok(result) +//} +// +///// Create md5 password hash given a salt. +//pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> { +// let mut md5 = Md5::new(); +// +// // First pass +// md5.update(&password.as_bytes()); +// md5.update(&user.as_bytes()); +// +// let output = md5.finalize_reset(); +// +// // Second pass +// md5.update(format!("{:x}", output)); +// md5.update(salt); +// +// let mut password = format!("md5{:x}", md5.finalize()) +// .chars() +// .map(|x| x as u8) +// .collect::<Vec<u8>>(); +// password.push(0); +// +// password +//} +// +///// Send password challenge response to the server. +///// This is the MD5 challenge. +//pub async fn md5_password<S>( +// stream: &mut S, +// user: &str, +// password: &str, +// salt: &[u8], +//) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// let password = md5_hash_password(user, password, salt); +// +// let mut message = BytesMut::with_capacity(password.len() as usize + 5); +// +// message.put_u8(b'p'); +// message.put_i32(password.len() as i32 + 4); +// message.put_slice(&password[..]); +// +// Ok(write_all(stream, message).await?) +//} +// +///// Implements a response to our custom `SET SHARDING KEY` +///// and `SET SERVER ROLE` commands. +///// This tells the client we're ready for the next query. 
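Those replies are just a CommandComplete frame followed by ReadyForQuery, so a minimal Go sketch is short; the helper names commandComplete and readyForQuery are mine rather than anything already in this package, and the commented-out Rust reference follows.

import "encoding/binary"

// commandComplete frames a CommandComplete ('C') message carrying a
// null-terminated command tag such as "SET" or "SELECT 1".
func commandComplete(tag string) []byte {
    body := append([]byte(tag), 0)
    msg := make([]byte, 5, 5+len(body))
    msg[0] = 'C'
    binary.BigEndian.PutUint32(msg[1:5], uint32(len(body)+4))
    return append(msg, body...)
}

// readyForQuery frames a ReadyForQuery ('Z') message with transaction
// status 'I' (idle), telling the client it may send the next query.
func readyForQuery() []byte {
    return []byte{'Z', 0, 0, 0, 5, 'I'}
}

Writing commandComplete(message) followed by readyForQuery() to the client stream would cover the custom-command acknowledgement described above.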
+//pub async fn custom_protocol_response_ok<S>(stream: &mut S, message: &str) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// let mut res = BytesMut::with_capacity(25); +// +// let set_complete = BytesMut::from(&format!("{}\0", message)[..]); +// let len = (set_complete.len() + 4) as i32; +// +// // CommandComplete +// res.put_u8(b'C'); +// res.put_i32(len); +// res.put_slice(&set_complete[..]); +// +// write_all_half(stream, res).await?; +// ready_for_query(stream).await +//} +// +///// Send a custom error message to the client. +///// Tell the client we are ready for the next query and no rollback is necessary. +///// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>. +//pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// error_response_terminal(stream, message).await?; +// ready_for_query(stream).await +//} +// +///// Send a custom error message to the client. +///// Tell the client we are ready for the next query and no rollback is necessary. +///// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>. +//pub async fn error_response_terminal<S>(stream: &mut S, message: &str) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// let mut error = BytesMut::new(); +// +// // Error level +// error.put_u8(b'S'); +// error.put_slice(&b"FATAL\0"[..]); +// +// // Error level (non-translatable) +// error.put_u8(b'V'); +// error.put_slice(&b"FATAL\0"[..]); +// +// // Error code: not sure how much this matters. +// error.put_u8(b'C'); +// error.put_slice(&b"58000\0"[..]); // system_error, see Appendix A. +// +// // The short error message. +// error.put_u8(b'M'); +// error.put_slice(&format!("{}\0", message).as_bytes()); +// +// // No more fields follow. +// error.put_u8(0); +// +// // Compose the two message reply. +// let mut res = BytesMut::with_capacity(error.len() + 5); +// +// res.put_u8(b'E'); +// res.put_i32(error.len() as i32 + 4); +// res.put(error); +// +// Ok(write_all_half(stream, res).await?) +//} +// +//pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// let mut error = BytesMut::new(); +// +// // Error level +// error.put_u8(b'S'); +// error.put_slice(&b"FATAL\0"[..]); +// +// // Error level (non-translatable) +// error.put_u8(b'V'); +// error.put_slice(&b"FATAL\0"[..]); +// +// // Error code: not sure how much this matters. +// error.put_u8(b'C'); +// error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A. +// +// // The short error message. +// error.put_u8(b'M'); +// error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes()); +// +// // No more fields follow. +// error.put_u8(0); +// +// // Compose the two message reply. +// let mut res = BytesMut::new(); +// +// res.put_u8(b'E'); +// res.put_i32(error.len() as i32 + 4); +// +// res.put(error); +// +// write_all(stream, res).await +//} +// +///// Respond to a SHOW SHARD command. +//pub async fn show_response<S>(stream: &mut S, name: &str, value: &str) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// // A SELECT response consists of: +// // 1. RowDescription +// // 2. One or more DataRow +// // 3. CommandComplete +// // 4. 
ReadyForQuery +// +// // The final messages sent to the client +// let mut res = BytesMut::new(); +// +// // RowDescription +// res.put(row_description(&vec![(name, DataType::Text)])); +// +// // DataRow +// res.put(data_row(&vec![value.to_string()])); +// +// // CommandComplete +// res.put(command_complete("SELECT 1")); +// +// write_all_half(stream, res).await?; +// ready_for_query(stream).await +//} +// +//pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { +// let mut res = BytesMut::new(); +// let mut row_desc = BytesMut::new(); +// +// // how many colums we are storing +// row_desc.put_i16(columns.len() as i16); +// +// for (name, data_type) in columns { +// // Column name +// row_desc.put_slice(&format!("{}\0", name).as_bytes()); +// +// // Doesn't belong to any table +// row_desc.put_i32(0); +// +// // Doesn't belong to any table +// row_desc.put_i16(0); +// +// // Text +// row_desc.put_i32(data_type.into()); +// +// // Text size = variable (-1) +// let type_size = match data_type { +// DataType::Text => -1, +// DataType::Int4 => 4, +// DataType::Numeric => -1, +// }; +// +// row_desc.put_i16(type_size); +// +// // Type modifier: none that I know +// row_desc.put_i32(-1); +// +// // Format being used: text (0), binary (1) +// row_desc.put_i16(0); +// } +// +// res.put_u8(b'T'); +// res.put_i32(row_desc.len() as i32 + 4); +// res.put(row_desc); +// +// res +//} +// +///// Create a DataRow message. +//pub fn data_row(row: &Vec<String>) -> BytesMut { +// let mut res = BytesMut::new(); +// let mut data_row = BytesMut::new(); +// +// data_row.put_i16(row.len() as i16); +// +// for column in row { +// let column = column.as_bytes(); +// data_row.put_i32(column.len() as i32); +// data_row.put_slice(&column); +// } +// +// res.put_u8(b'D'); +// res.put_i32(data_row.len() as i32 + 4); +// res.put(data_row); +// +// res +//} +// +///// Create a CommandComplete message. +//pub fn command_complete(command: &str) -> BytesMut { +// let cmd = BytesMut::from(format!("{}\0", command).as_bytes()); +// let mut res = BytesMut::new(); +// res.put_u8(b'C'); +// res.put_i32(cmd.len() as i32 + 4); +// res.put(cmd); +// res +//} +// +///// Write all data in the buffer to the TcpStream. +//pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// match stream.write_all(&buf).await { +// Ok(_) => Ok(()), +// Err(_) => return Err(Error::SocketError), +// } +//} +// +///// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). +//pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error> +//where +// S: tokio::io::AsyncWrite + std::marker::Unpin, +//{ +// match stream.write_all(&buf).await { +// Ok(_) => Ok(()), +// Err(_) => return Err(Error::SocketError), +// } +//} +// +///// Read a complete message from the socket. 
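A Go version of that read could look roughly like the sketch below; the name readMessage and the error text are illustrative, and the commented-out Rust it mirrors follows.

import (
    "encoding/binary"
    "fmt"
    "io"
)

// readMessage reads a single framed message: a 1-byte type code, a 4-byte
// big-endian length that includes itself but not the code, and then the
// remaining payload. The whole frame is returned so it can be forwarded
// or inspected verbatim.
func readMessage(r io.Reader) ([]byte, error) {
    header := make([]byte, 5)
    if _, err := io.ReadFull(r, header); err != nil {
        return nil, err
    }
    length := binary.BigEndian.Uint32(header[1:5])
    if length < 4 {
        return nil, fmt.Errorf("malformed message length %d", length)
    }
    // code byte + length (which already counts its own 4 bytes)
    frame := make([]byte, 1+int(length))
    copy(frame, header)
    if _, err := io.ReadFull(r, frame[5:]); err != nil {
        return nil, err
    }
    return frame, nil
}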
+//pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error> +//where +// S: tokio::io::AsyncRead + std::marker::Unpin, +//{ +// let code = match stream.read_u8().await { +// Ok(code) => code, +// Err(_) => return Err(Error::SocketError), +// }; +// +// let len = match stream.read_i32().await { +// Ok(len) => len, +// Err(_) => return Err(Error::SocketError), +// }; +// +// let mut buf = vec![0u8; len as usize - 4]; +// +// match stream.read_exact(&mut buf).await { +// Ok(_) => (), +// Err(_) => return Err(Error::SocketError), +// }; +// +// let mut bytes = BytesMut::with_capacity(len as usize + 1); +// +// bytes.put_u8(code); +// bytes.put_i32(len); +// bytes.put_slice(&buf); +// +// Ok(bytes) +//} +// +//pub fn server_paramater_message(key: &str, value: &str) -> BytesMut { +// let mut server_info = BytesMut::new(); +// +// let null_byte_size = 1; +// let len: usize = +// mem::size_of::<i32>() + key.len() + null_byte_size + value.len() + null_byte_size; +// +// server_info.put_slice("S".as_bytes()); +// server_info.put_i32(len.try_into().unwrap()); +// server_info.put_slice(key.as_bytes()); +// server_info.put_bytes(0, 1); +// server_info.put_slice(value.as_bytes()); +// server_info.put_bytes(0, 1); +// +// return server_info; +//} diff --git a/lib/gat/pool.go b/lib/gat/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..6fc532688a7526df69f02661906fa3dd5cf267f0 --- /dev/null +++ b/lib/gat/pool.go @@ -0,0 +1,566 @@ +package gat + +import ( + "gfx.cafe/gfx/pggat/lib/config" +) + +type PoolSettings struct { + /// Transaction or Session. + pool_mode config.PoolMode + + // Number of shards. + shards int + + // Connecting user. + user config.User + + // Default server role to connect to. + default_role config.ServerRole + + // Enable/disable query parser. + query_parser_enabled bool + + // Read from the primary as well or not. + primary_reads_enabled bool + + // Sharding function. + sharding_function ShardFunc +} + +func DefaultPool() PoolSettings { + return PoolSettings{ + pool_mode: config.POOLMODE_TXN, + shards: 1, + user: config.User{ + Name: "postgres", + Password: "test", + }, + default_role: config.SERVERROLE_NONE, + query_parser_enabled: false, + primary_reads_enabled: true, + //TODO: select default sharding function + sharding_function: nil, + } +} + +type ServerPool struct { + address string + user config.User + database string + client_server_map map[ClientKey]ClientInfo + stats any //TODO: stats +} + +//TODO: implement server pool +//#[async_trait] +//impl ManageConnection for ServerPool { +// type Connection = Server; +// type Error = Error; +// +// /// Attempts to create a new connection. +// async fn connect(&self) -> Result<Self::Connection, Self::Error> { +// info!( +// "Creating a new connection to {:?} using user {:?}", +// self.address.name(), +// self.user.username +// ); +// +// // Put a temporary process_id into the stats +// // for server login. +// let process_id = rand::random::<i32>(); +// self.stats.server_login(process_id, self.address.id); +// +// // Connect to the PostgreSQL server. +// match Server::startup( +// &self.address, +// &self.user, +// &self.database, +// self.client_server_map.clone(), +// self.stats.clone(), +// ) +// .await +// { +// Ok(conn) => { +// // Remove the temporary process_id from the stats. +// self.stats.server_disconnecting(process_id, self.address.id); +// Ok(conn) +// } +// Err(err) => { +// // Remove the temporary process_id from the stats. 
+// self.stats.server_disconnecting(process_id, self.address.id); +// Err(err) +// } +// } +// } +// +// /// Determines if the connection is still connected to the database. +// async fn is_valid(&self, _conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> { +// Ok(()) +// } +// +// /// Synchronously determine if the connection is no longer usable, if possible. +// fn has_broken(&self, conn: &mut Self::Connection) -> bool { +// conn.is_bad() +// } +//} +// +///// Get the connection pool +//pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> { +// match get_all_pools().get(&(db, user)) { +// Some(pool) => Some(pool.clone()), +// None => None, +// } +//} +// +///// How many total servers we have in the config. +//pub fn get_number_of_addresses() -> usize { +// get_all_pools() +// .iter() +// .map(|(_, pool)| pool.databases()) +// .sum() +//} +// +///// Get a pointer to all configured pools. +//pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> { +// return (*(*POOLS.load())).clone(); +//} + +type ConnectionPool struct { + /// The pools handled internally by bb8. + // TODO: https://docs.rs/bb8-bolt/latest/src/bb8_bolt/lib.rs.html + databases [][]ServerPool + + /// The addresses (host, port, role) to handle + /// failover and load balancing deterministically. + addresses []config.Server + + /// List of banned addresses (see above) + /// that should not be queried. + banlist []string + + /// The statistics aggregator runs in a separate task + /// and receives stats from clients, servers, and the pool. + stats any // TODO: stats + + /// The server information (K messages) have to be passed to the + /// clients on startup. We pre-connect to all shards and replicas + /// on pool creation and save the K messages here. + /// TODO: consider storing this in a better format + server_info []byte + + /// Pool configuration. + settings PoolSettings +} + +//TODO: implement this +// /// Construct the connection pool from the configuration. +// func (c *ConnectionPool) from_config(client_server_map: ClientServerMap) Result<(), Error> { +// let config = get_config() +// +// new_pools = HashMap::new() +// address_id = 0 +// +// for (pool_name, pool_config) in &config.pools { +// // There is one pool per database/user pair. +// for (_, user) in &pool_config.users { +// shards = Vec::new() +// addresses = Vec::new() +// banlist = Vec::new() +// shard_ids = pool_config +// .shards +// .clone() +// .into_keys() +// .map(|x| x.to_string()) +// .collect::<Vec<string>>() +// +// // Sort by shard number to ensure consistency. +// shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap()) +// +// for shard_idx in &shard_ids { +// let shard = &pool_config.shards[shard_idx] +// pools = Vec::new() +// servers = Vec::new() +// address_index = 0 +// replica_number = 0 +// +// for server in shard.servers.iter() { +// let role = match server.2.as_ref() { +// "primary" => Role::Primary, +// "replica" => Role::Replica, +// _ => { +// error!("Config error: server role can be 'primary' or 'replica', have: '{}'. 
Defaulting to 'replica'.", server.2) +// Role::Replica +// } +// } +// +// let address = Address { +// id: address_id, +// database: shard.database.clone(), +// host: server.0.clone(), +// port: server.1 as u16, +// role: role, +// address_index, +// replica_number, +// shard: shard_idx.parse::<usize>().unwrap(), +// username: user.username.clone(), +// pool_name: pool_name.clone(), +// } +// +// address_id += 1 +// address_index += 1 +// +// if role == Role::Replica { +// replica_number += 1 +// } +// +// let manager = ServerPool::new( +// address.clone(), +// user.clone(), +// &shard.database, +// client_server_map.clone(), +// get_reporter(), +// ) +// +// let pool = Pool::builder() +// .max_size(user.pool_size) +// .connection_timeout(std::time::Duration::from_millis( +// config.general.connect_timeout, +// )) +// .test_on_check_out(false) +// .build(manager) +// .await +// .unwrap() +// +// pools.push(pool) +// servers.push(address) +// } +// +// shards.push(pools) +// addresses.push(servers) +// banlist.push(HashMap::new()) +// } +// +// assert_eq!(shards.len(), addresses.len()) +// +// pool = ConnectionPool { +// databases: shards, +// addresses: addresses, +// banlist: Arc::new(RwLock::new(banlist)), +// stats: get_reporter(), +// server_info: BytesMut::new(), +// settings: PoolSettings { +// pool_mode: match pool_config.pool_mode.as_str() { +// "transaction" => PoolMode::Transaction, +// "session" => PoolMode::Session, +// _ => unreachable!(), +// }, +// // shards: pool_config.shards.clone(), +// shards: shard_ids.len(), +// user: user.clone(), +// default_role: match pool_config.default_role.as_str() { +// "any" => None, +// "replica" => Some(Role::Replica), +// "primary" => Some(Role::Primary), +// _ => unreachable!(), +// }, +// query_parser_enabled: pool_config.query_parser_enabled.clone(), +// primary_reads_enabled: pool_config.primary_reads_enabled, +// sharding_function: match pool_config.sharding_function.as_str() { +// "pg_bigint_hash" => ShardingFunction::PgBigintHash, +// "sha1" => ShardingFunction::Sha1, +// _ => unreachable!(), +// }, +// }, +// } +// +// // Connect to the servers to make sure pool configuration is valid +// // before setting it globally. +// match pool.validate().await { +// Ok(_) => (), +// Err(err) => { +// error!("Could not validate connection pool: {:?}", err) +// return Err(err) +// } +// } +// +// // There is one pool per database/user pair. +// new_pools.insert((pool_name.clone(), user.username.clone()), pool) +// } +// } +// +// POOLS.store(Arc::new(new_pools.clone())) +// +// Ok(()) +// } +// +// /// Connect to all shards and grab server information. +// /// Return server information we will pass to the clients +// /// when they connect. +// /// This also warms up the pool for clients that connect when +// /// the pooler starts up. +// async fn validate(&mut self) Result<(), Error> { +// server_infos = Vec::new() +// for shard in 0..self.shards() { +// for server in 0..self.servers(shard) { +// let connection = match self.databases[shard][server].get().await { +// Ok(conn) => conn, +// Err(err) => { +// error!("Shard {} down or misconfigured: {:?}", shard, err) +// continue +// } +// } +// +// let proxy = connection +// let server = &*proxy +// let server_info = server.server_info() +// +// if server_infos.len() > 0 { +// // Compare against the last server checked. 
+// if server_info != server_infos[server_infos.len() - 1] { +// warn!( +// "{:?} has different server configuration than the last server", +// proxy.address() +// ) +// } +// } +// +// server_infos.push(server_info) +// } +// } +// +// // TODO: compare server information to make sure +// // all shards are running identical configurations. +// if server_infos.len() == 0 { +// return Err(Error::AllServersDown) +// } +// +// // We're assuming all servers are identical. +// // TODO: not true. +// self.server_info = server_infos[0].clone() +// +// Ok(()) +// } +// +// /// Get a connection from the pool. +// func (c *ConnectionPool) get( +// &self, +// shard: usize, // shard number +// role: Option<Role>, // primary or replica +// process_id: i32, // client id +// ) Result<(PooledConnection<'_, ServerPool>, Address), Error> { +// let now = Instant::now() +// candidates: Vec<&Address> = self.addresses[shard] +// .iter() +// .filter(|address| address.role == role) +// .collect() +// +// // Random load balancing +// candidates.shuffle(&mut thread_rng()) +// +// let healthcheck_timeout = get_config().general.healthcheck_timeout +// let healthcheck_delay = get_config().general.healthcheck_delay as u128 +// +// while !candidates.is_empty() { +// // Get the next candidate +// let address = match candidates.pop() { +// Some(address) => address, +// None => break, +// } +// +// if self.is_banned(&address, role) { +// debug!("Address {:?} is banned", address) +// continue +// } +// +// // Indicate we're waiting on a server connection from a pool. +// self.stats.client_waiting(process_id, address.id) +// +// // Check if we can connect +// conn = match self.databases[address.shard][address.address_index] +// .get() +// .await +// { +// Ok(conn) => conn, +// Err(err) => { +// error!("Banning instance {:?}, error: {:?}", address, err) +// self.ban(&address, process_id) +// self.stats +// .checkout_time(now.elapsed().as_micros(), process_id, address.id) +// continue +// } +// } +// +// // // Check if this server is alive with a health check. +// let server = &mut *conn +// +// // Will return error if timestamp is greater than current system time, which it should never be set to +// let require_healthcheck = +// server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay +// +// // Do not issue a health check unless it's been a little while +// // since we last checked the server is ok. +// // Health checks are pretty expensive. +// if !require_healthcheck { +// self.stats +// .checkout_time(now.elapsed().as_micros(), process_id, address.id) +// self.stats.server_active(conn.process_id(), address.id) +// return Ok((conn, address.clone())) +// } +// +// debug!("Running health check on server {:?}", address) +// +// self.stats.server_tested(server.process_id(), address.id) +// +// match tokio::time::timeout( +// tokio::time::Duration::from_millis(healthcheck_timeout), +// server.query(""), // Cheap query (query parser not used in PG) +// ) +// .await +// { +// // Check if health check succeeded. +// Ok(res) => match res { +// Ok(_) => { +// self.stats +// .checkout_time(now.elapsed().as_micros(), process_id, address.id) +// self.stats.server_active(conn.process_id(), address.id) +// return Ok((conn, address.clone())) +// } +// +// // Health check failed. +// Err(err) => { +// error!( +// "Banning instance {:?} because of failed health check, {:?}", +// address, err +// ) +// +// // Don't leave a bad connection in the pool. 
+// server.mark_bad() +// +// self.ban(&address, process_id) +// continue +// } +// }, +// +// // Health check timed out. +// Err(err) => { +// error!( +// "Banning instance {:?} because of health check timeout, {:?}", +// address, err +// ) +// // Don't leave a bad connection in the pool. +// server.mark_bad() +// +// self.ban(&address, process_id) +// continue +// } +// } +// } +// +// Err(Error::AllServersDown) +// } +// +// /// Ban an address (i.e. replica). It no longer will serve +// /// traffic for any new transactions. Existing transactions on that replica +// /// will finish successfully or error out to the clients. +// func (c *ConnectionPool) ban(&self, address: &Address, process_id: i32) { +// self.stats.client_disconnecting(process_id, address.id) +// +// error!("Banning {:?}", address) +// +// let now = chrono::offset::Utc::now().naive_utc() +// guard = self.banlist.write() +// guard[address.shard].insert(address.clone(), now) +// } +// +// /// Clear the replica to receive traffic again. Takes effect immediately +// /// for all new transactions. +// func (c *ConnectionPool) _unban(&self, address: &Address) { +// guard = self.banlist.write() +// guard[address.shard].remove(address) +// } +// +// /// Check if a replica can serve traffic. If all replicas are banned, +// /// we unban all of them. Better to try then not to. +// func (c *ConnectionPool) is_banned(&self, address: &Address, role: Option<Role>) bool { +// let replicas_available = match role { +// Some(Role::Replica) => self.addresses[address.shard] +// .iter() +// .filter(|addr| addr.role == Role::Replica) +// .count(), +// None => self.addresses[address.shard].len(), +// Some(Role::Primary) => return false, // Primary cannot be banned. +// } +// +// debug!("Available targets for {:?}: {}", role, replicas_available) +// +// let guard = self.banlist.read() +// +// // Everything is banned = nothing is banned. +// if guard[address.shard].len() == replicas_available { +// drop(guard) +// guard = self.banlist.write() +// guard[address.shard].clear() +// drop(guard) +// warn!("Unbanning all replicas.") +// return false +// } +// +// // I expect this to miss 99.9999% of the time. +// match guard[address.shard].get(address) { +// Some(timestamp) => { +// let now = chrono::offset::Utc::now().naive_utc() +// let config = get_config() +// +// // Ban expired. +// if now.timestamp() - timestamp.timestamp() > config.general.ban_time { +// drop(guard) +// warn!("Unbanning {:?}", address) +// guard = self.banlist.write() +// guard[address.shard].remove(address) +// false +// } else { +// debug!("{:?} is banned", address) +// true +// } +// } +// +// None => { +// debug!("{:?} is ok", address) +// false +// } +// } +// } +// +// /// Get the number of configured shards. +// func (c *ConnectionPool) shards(&self) usize { +// self.databases.len() +// } +// +// /// Get the number of servers (primary and replicas) +// /// configured for a shard. +// func (c *ConnectionPool) servers(&self, shard: usize) usize { +// self.addresses[shard].len() +// } +// +// /// Get the total number of servers (databases) we are connected to. +// func (c *ConnectionPool) databases(&self) usize { +// databases = 0 +// for shard in 0..self.shards() { +// databases += self.servers(shard) +// } +// databases +// } +// +// /// Get pool state for a particular shard server as reported by bb8. 
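Before the bb8-specific pool_state accessor below, the counting helpers sketched above map almost one-to-one onto the ConnectionPool struct declared earlier in this file. A possible Go version follows; the method names are mine, and it counts via the databases field because addresses is currently a flat slice in this draft.

// Shards returns the number of configured shards.
func (c *ConnectionPool) Shards() int {
    return len(c.databases)
}

// Servers returns how many servers (primary plus replicas) are configured
// for the given shard.
func (c *ConnectionPool) Servers(shard int) int {
    return len(c.databases[shard])
}

// Databases returns the total number of server connections managed across
// all shards.
func (c *ConnectionPool) Databases() int {
    total := 0
    for shard := 0; shard < c.Shards(); shard++ {
        total += c.Servers(shard)
    }
    return total
}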
+// func (c *ConnectionPool) pool_state(&self, shard: usize, server: usize) bb8::State { +// self.databases[shard][server].state() +// } +// +// /// Get the address information for a shard server. +// func (c *ConnectionPool) address(&self, shard: usize, server: usize) &Address { +// &self.addresses[shard][server] +// } +// +// func (c *ConnectionPool) server_info(&self) BytesMut { +// self.server_info.clone() +// } diff --git a/lib/gat/query_router.go b/lib/gat/query_router.go new file mode 100644 index 0000000000000000000000000000000000000000..9aaa5b3b5983ee31995b3cf5df5a6a678a990db2 --- /dev/null +++ b/lib/gat/query_router.go @@ -0,0 +1,293 @@ +package gat + +import ( + "log" + "regexp" + + "gfx.cafe/gfx/pggat/lib/config" +) + +var compiler = regexp.MustCompile + +var CustomSqlRegex = []*regexp.Regexp{ + compiler("(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$"), + compiler("(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$"), + compiler("(?i)^ *SHOW SHARD *;? *$"), + compiler("(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$"), + compiler("(?i)^ *SHOW SERVER ROLE *;? *$"), + compiler("(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$"), + compiler("(?i)^ *SHOW PRIMARY READS *;? *$"), +} + +type Command interface { +} + +var _ []Command = []Command{ + &CommandSetShardingKey{}, + &CommandSetShard{}, + &CommandShowShard{}, + &CommandSetServerRole{}, + &CommandShowServerRole{}, + &CommandSetPrimaryReads{}, + &CommandShowPrimaryReads{}, +} + +type CommandSetShardingKey struct{} +type CommandSetShard struct{} +type CommandShowShard struct{} +type CommandSetServerRole struct{} +type CommandShowServerRole struct{} +type CommandSetPrimaryReads struct{} +type CommandShowPrimaryReads struct{} + +type QueryRouter struct { + active_shard int + active_role config.ServerRole + query_parser_enabled bool + primary_reads_enabled bool + pool_settings PoolSettings +} + +// / Pool settings can change because of a config reload. +func (r *QueryRouter) UpdatePoolSettings(pool_settings PoolSettings) { + r.pool_settings = pool_settings +} + +// / Try to parse a command and execute it. +func (r *QueryRouter) try_execute_command(buf []byte) (Command, string) { + // Only simple protocol supported for commands. + if buf[0] != 'Q' { + return nil, "" + } + msglen := 0 + // TODO: read msg len + // msglen := buf.get_i32() + custom := false + for _, v := range CustomSqlRegex { + if v.Match(buf[:msglen-5]) { + custom = true + break + } + } + // This is not a custom query, try to infer which + // server it'll go to if the query parser is enabled. + if !custom { + log.Println("Regular query, not a command") + return nil, "" + } + + // TODO: command matching + //command := switch matches[0] { + // 0 => Command::SetShardingKey, + // 1 => Command::SetShard, + // 2 => Command::ShowShard, + // 3 => Command::SetServerRole, + // 4 => Command::ShowServerRole, + // 5 => Command::SetPrimaryReads, + // 6 => Command::ShowPrimaryReads, + // _ => unreachable!(), + //} + + //mut value := switch command { + // Command::SetShardingKey + // | Command::SetShard + // | Command::SetServerRole + // | Command::SetPrimaryReads => { + // // Capture value. I know this re-runs the regex engine, but I haven't + // // figured out a better way just yet. I think I can write a single Regex + // // that switches all 5 custom SQL patterns, but maybe that's not very legible? + // // + // // I think this is faster than running the Regex engine 5 times. 
+ // switch regex_list[matches[0]].captures(&query) { + // Some(captures) => switch captures.get(1) { + // Some(value) => value.as_str().to_string(), + // None => return None, + // }, + // None => return None, + // } + // } + + // Command::ShowShard => self.shard().to_string(), + // Command::ShowServerRole => switch self.active_role { + // Some(Role::Primary) => string("primary"), + // Some(Role::Replica) => string("replica"), + // None => { + // if self.query_parser_enabled { + // string("auto") + // } else { + // string("any") + // } + // } + // }, + + // Command::ShowPrimaryReads => switch self.primary_reads_enabled { + // true => string("on"), + // false => string("off"), + // }, + //} + + //switch command { + // Command::SetShardingKey => { + // sharder := Sharder::new( + // self.pool_settings.shards, + // self.pool_settings.sharding_function, + // ) + // shard := sharder.shard(value.parse::<i64>().unwrap()) + // self.active_shard := Some(shard) + // value := shard.to_string() + // } + + // Command::SetShard => { + // self.active_shard := switch value.to_ascii_uppercase().as_ref() { + // "ANY" => Some(rand::random::<usize>() % self.pool_settings.shards), + // _ => Some(value.parse::<usize>().unwrap()), + // } + // } + + // Command::SetServerRole => { + // self.active_role := switch value.to_ascii_lowercase().as_ref() { + // "primary" => { + // self.query_parser_enabled := false + // Some(Role::Primary) + // } + + // "replica" => { + // self.query_parser_enabled := false + // Some(Role::Replica) + // } + + // "any" => { + // self.query_parser_enabled := false + // None + // } + + // "auto" => { + // self.query_parser_enabled := true + // None + // } + + // "default" => { + // self.active_role := self.pool_settings.default_role + // self.query_parser_enabled := self.query_parser_enabled + // self.active_role + // } + + // _ => unreachable!(), + // } + // } + + // Command::SetPrimaryReads => { + // if value == "on" { + // log.Println("Setting primary reads to on") + // self.primary_reads_enabled := true + // } else if value == "off" { + // log.Println("Setting primary reads to off") + // self.primary_reads_enabled := false + // } else if value == "default" { + // log.Println("Setting primary reads to default") + // self.primary_reads_enabled := self.pool_settings.primary_reads_enabled + // } + // } + + // _ => (), + //} + + //Some((command, value)) + return nil, "" +} + +// / Try to infer which server to connect to based on the contents of the query. +// TODO: implement +func (r *QueryRouter) InferRole(buf []byte) bool { + log.Println("Inferring role") + + //code := buf.get_u8() as char + //len := buf.get_i32() as usize + + //query := switch code { + // // Query + // 'Q' => { + // query := string(&buf[:len - 5]).to_string() + // log.Println("Query: '%v'", query) + // query + // } + + // // Parse (prepared statement) + // 'P' => { + // mut start := 0 + // mut end + + // // Skip the name of the prepared statement. 
+ // while buf[start] != 0 && start < buf.len() { + // start += 1 + // } + // start += 1 // Skip terminating null + + // // Find the end of the prepared stmt (\0) + // end := start + // while buf[end] != 0 && end < buf.len() { + // end += 1 + // } + + // query := string(&buf[start:end]).to_string() + + // log.Println("Prepared statement: '%v'", query) + + // query.replace("$", "") // Remove placeholders turning them into "values" + // } + + // _ => return false, + //} + + //ast := switch Parser::parse_sql(&PostgreSqlDialect %v, &query) { + // Ok(ast) => ast, + // Err(err) => { + // log.Println("%v", err.to_string()) + // return false + // } + //} + + //if ast.len() == 0 { + // return false + //} + + //switch ast[0] { + // // All transactions go to the primary, probably a write. + // StartTransaction { : } => { + // self.active_role := Some(Role::Primary) + // } + + // // Likely a read-only query + // Query { : } => { + // self.active_role := switch self.primary_reads_enabled { + // false => Some(Role::Replica), // If primary should not be receiving reads, use a replica. + // true => None, // Any server role is fine in this case. + // } + // } + + // // Likely a write + // _ => { + // self.active_role := Some(Role::Primary) + // } + //} + + return true +} + +// / Get the current desired server role we should be talking to. +func (r *QueryRouter) Role() config.ServerRole { + return r.active_role +} + +// / Get desired shard we should be talking to. +func (r *QueryRouter) Shard() int { + return r.active_shard +} + +func (r *QueryRouter) SetShard(shard int) { + r.active_shard = shard +} + +func (r *QueryRouter) QueryParserEnabled() bool { + return r.query_parser_enabled +} diff --git a/lib/gat/query_router_test.go b/lib/gat/query_router_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9f9c0b18eb7ddf6fe52180f382b860de790cace3 --- /dev/null +++ b/lib/gat/query_router_test.go @@ -0,0 +1,302 @@ +package gat + +//TODO: adapt tests +//#[cfg(test)] +//mod test { +// use super::*; +// use crate::messages::simple_query; +// use crate::pool::PoolMode; +// use crate::sharding::ShardingFunction; +// use bytes::BufMut; +// +// #[test] +// fn test_defaults() { +// QueryRouter::setup(); +// let qr = QueryRouter::new(); +// +// assert_eq!(qr.role(), None); +// } +// +// #[test] +// fn test_infer_role_replica() { +// QueryRouter::setup(); +// let mut qr = QueryRouter::new(); +// assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); +// assert_eq!(qr.query_parser_enabled(), true); +// +// assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); +// +// let queries = vec![ +// simple_query("SELECT * FROM items WHERE id = 5"), +// simple_query( +// "SELECT id, name, value FROM items INNER JOIN prices ON item.id = prices.item_id", +// ), +// simple_query("WITH t AS (SELECT * FROM items) SELECT * FROM t"), +// ]; +// +// for query in queries { +// // It's a recognized query +// assert!(qr.infer_role(query)); +// assert_eq!(qr.role(), Some(Role::Replica)); +// } +// } +// +// #[test] +// fn test_infer_role_primary() { +// QueryRouter::setup(); +// let mut qr = QueryRouter::new(); +// +// let queries = vec![ +// simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"), +// simple_query("INSERT INTO items (id, name) VALUES (5, 'pumpkin')"), +// simple_query("DELETE FROM items WHERE id = 5"), +// simple_query("BEGIN"), // Transaction start +// ]; +// +// for query in queries { +// // It's a recognized query +// 
assert!(qr.infer_role(query)); +// assert_eq!(qr.role(), Some(Role::Primary)); +// } +// } +// +// #[test] +// fn test_infer_role_primary_reads_enabled() { +// QueryRouter::setup(); +// let mut qr = QueryRouter::new(); +// let query = simple_query("SELECT * FROM items WHERE id = 5"); +// assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); +// +// assert!(qr.infer_role(query)); +// assert_eq!(qr.role(), None); +// } +// +// #[test] +// fn test_infer_role_parse_prepared() { +// QueryRouter::setup(); +// let mut qr = QueryRouter::new(); +// qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); +// assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); +// +// let prepared_stmt = BytesMut::from( +// &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], +// ); +// let mut res = BytesMut::from(&b"P"[..]); +// res.put_i32(prepared_stmt.len() as i32 + 4 + 1 + 2); +// res.put_u8(0); +// res.put(prepared_stmt); +// res.put_i16(0); +// +// assert!(qr.infer_role(res)); +// assert_eq!(qr.role(), Some(Role::Replica)); +// } +// +// #[test] +// fn test_regex_set() { +// QueryRouter::setup(); +// +// let tests = [ +// // Upper case +// "SET SHARDING KEY TO '1'", +// "SET SHARD TO '1'", +// "SHOW SHARD", +// "SET SERVER ROLE TO 'replica'", +// "SET SERVER ROLE TO 'primary'", +// "SET SERVER ROLE TO 'any'", +// "SET SERVER ROLE TO 'auto'", +// "SHOW SERVER ROLE", +// "SET PRIMARY READS TO 'on'", +// "SET PRIMARY READS TO 'off'", +// "SET PRIMARY READS TO 'default'", +// "SHOW PRIMARY READS", +// // Lower case +// "set sharding key to '1'", +// "set shard to '1'", +// "show shard", +// "set server role to 'replica'", +// "set server role to 'primary'", +// "set server role to 'any'", +// "set server role to 'auto'", +// "show server role", +// "set primary reads to 'on'", +// "set primary reads to 'OFF'", +// "set primary reads to 'deFaUlt'", +// // No quotes +// "SET SHARDING KEY TO 11235", +// "SET SHARD TO 15", +// "SET PRIMARY READS TO off", +// // Spaces and semicolon +// " SET SHARDING KEY TO 11235 ; ", +// " SET SHARD TO 15; ", +// " SET SHARDING KEY TO 11235 ;", +// " SET SERVER ROLE TO 'primary'; ", +// " SET SERVER ROLE TO 'primary' ; ", +// " SET SERVER ROLE TO 'primary' ;", +// " SET PRIMARY READS TO 'off' ;", +// ]; +// +// // Which regexes it'll match to in the list +// let matches = [ +// 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 6, 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 0, 1, 5, 0, 1, 0, +// 3, 3, 3, 5, +// ]; +// +// let list = CUSTOM_SQL_REGEX_LIST.get().unwrap(); +// let set = CUSTOM_SQL_REGEX_SET.get().unwrap(); +// +// for (i, test) in tests.iter().enumerate() { +// if !list[matches[i]].is_match(test) { +// println!("{} does not match {}", test, list[matches[i]]); +// assert!(false); +// } +// assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1); +// } +// +// let bad = [ +// "SELECT * FROM table", +// "SELECT * FROM table WHERE value = 'set sharding key to 5'", // Don't capture things in the middle of the query +// ]; +// +// for query in &bad { +// assert_eq!(set.matches(query).into_iter().collect::<Vec<_>>().len(), 0); +// } +// } +// +// #[test] +// fn test_try_execute_command() { +// QueryRouter::setup(); +// let mut qr = QueryRouter::new(); +// +// // SetShardingKey +// let query = simple_query("SET SHARDING KEY TO 13"); +// assert_eq!( +// qr.try_execute_command(query), +// Some((Command::SetShardingKey, String::from("0"))) +// ); +// assert_eq!(qr.shard(), 0); +// +// // 
SetShard +// let query = simple_query("SET SHARD TO '1'"); +// assert_eq!( +// qr.try_execute_command(query), +// Some((Command::SetShard, String::from("1"))) +// ); +// assert_eq!(qr.shard(), 1); +// +// // ShowShard +// let query = simple_query("SHOW SHARD"); +// assert_eq!( +// qr.try_execute_command(query), +// Some((Command::ShowShard, String::from("1"))) +// ); +// +// // SetServerRole +// let roles = ["primary", "replica", "any", "auto", "primary"]; +// let verify_roles = [ +// Some(Role::Primary), +// Some(Role::Replica), +// None, +// None, +// Some(Role::Primary), +// ]; +// let query_parser_enabled = [false, false, false, true, false]; +// +// for (idx, role) in roles.iter().enumerate() { +// let query = simple_query(&format!("SET SERVER ROLE TO '{}'", role)); +// assert_eq!( +// qr.try_execute_command(query), +// Some((Command::SetServerRole, String::from(*role))) +// ); +// assert_eq!(qr.role(), verify_roles[idx],); +// assert_eq!(qr.query_parser_enabled(), query_parser_enabled[idx],); +// +// // ShowServerRole +// let query = simple_query("SHOW SERVER ROLE"); +// assert_eq!( +// qr.try_execute_command(query), +// Some((Command::ShowServerRole, String::from(*role))) +// ); +// } +// +// let primary_reads = ["on", "off", "default"]; +// let primary_reads_enabled = ["on", "off", "on"]; +// +// for (idx, primary_reads) in primary_reads.iter().enumerate() { +// assert_eq!( +// qr.try_execute_command(simple_query(&format!( +// "SET PRIMARY READS TO {}", +// primary_reads +// ))), +// Some((Command::SetPrimaryReads, String::from(*primary_reads))) +// ); +// assert_eq!( +// qr.try_execute_command(simple_query("SHOW PRIMARY READS")), +// Some(( +// Command::ShowPrimaryReads, +// String::from(primary_reads_enabled[idx]) +// )) +// ); +// } +// } +// +// #[test] +// fn test_enable_query_parser() { +// QueryRouter::setup(); +// let mut qr = QueryRouter::new(); +// let query = simple_query("SET SERVER ROLE TO 'auto'"); +// assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); +// +// assert!(qr.try_execute_command(query) != None); +// assert!(qr.query_parser_enabled()); +// assert_eq!(qr.role(), None); +// +// let query = simple_query("INSERT INTO test_table VALUES (1)"); +// assert_eq!(qr.infer_role(query), true); +// assert_eq!(qr.role(), Some(Role::Primary)); +// +// let query = simple_query("SELECT * FROM test_table"); +// assert_eq!(qr.infer_role(query), true); +// assert_eq!(qr.role(), Some(Role::Replica)); +// +// assert!(qr.query_parser_enabled()); +// let query = simple_query("SET SERVER ROLE TO 'default'"); +// assert!(qr.try_execute_command(query) != None); +// assert!(qr.query_parser_enabled()); +// } +// +// #[test] +// fn test_update_from_pool_settings() { +// QueryRouter::setup(); +// +// let pool_settings = PoolSettings { +// pool_mode: PoolMode::Transaction, +// shards: 0, +// user: crate::config::User::default(), +// default_role: Some(Role::Replica), +// query_parser_enabled: true, +// primary_reads_enabled: false, +// sharding_function: ShardingFunction::PgBigintHash, +// }; +// let mut qr = QueryRouter::new(); +// assert_eq!(qr.active_role, None); +// assert_eq!(qr.active_shard, None); +// assert_eq!(qr.query_parser_enabled, false); +// assert_eq!(qr.primary_reads_enabled, false); +// +// // Internal state must not be changed due to this, only defaults +// qr.update_pool_settings(pool_settings.clone()); +// +// assert_eq!(qr.active_role, None); +// assert_eq!(qr.active_shard, None); +// assert_eq!(qr.query_parser_enabled, false); +// 
assert_eq!(qr.primary_reads_enabled, false); +// +// let q1 = simple_query("SET SERVER ROLE TO 'primary'"); +// assert!(qr.try_execute_command(q1) != None); +// assert_eq!(qr.active_role.unwrap(), Role::Primary); +// +// let q2 = simple_query("SET SERVER ROLE TO 'default'"); +// assert!(qr.try_execute_command(q2) != None); +// assert_eq!(qr.active_role.unwrap(), pool_settings.clone().default_role); +// } +//} diff --git a/lib/gat/server.go b/lib/gat/server.go index 117ea073e2e2e1daa15db66a1bd474c458876be0..6194373fafd22ac56291310619362285f6f0c03f 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -117,7 +117,6 @@ func (s *Server) startup(ctx context.Context) error { return err } buf2.Write(buf.Bytes()) - log.Println(buf2) _, err = s.wr.Write(buf2.Bytes()) if err != nil { return err @@ -143,7 +142,7 @@ func (s *Server) connect(ctx context.Context) error { return err } msglen := int(msglen32) - log.Println(string(code), msglen) + s.log.Debug().Str("code", string(code)).Int("len", msglen).Msg("startup msg") switch code { case 'R': var auth_code int32 diff --git a/lib/gat/sharding.go b/lib/gat/sharding.go new file mode 100644 index 0000000000000000000000000000000000000000..4deb7267fedade1131a5b02456f59b371971fba7 --- /dev/null +++ b/lib/gat/sharding.go @@ -0,0 +1,122 @@ +package gat + +const PARTITION_HASH_SEED = 0x7A5B22367996DCFD + +type ShardFunc func(int64) int + +type Sharder struct { + shards int + fn ShardFunc +} + +func NewSharder(shards int, fn ShardFunc) *Sharder { + return &Sharder{ + shards: shards, + fn: fn, + } +} + +//TODO: implement hash functions +// +// fn pg_bigint_hash(&self, key: i64) -> usize { +// let mut lohalf = key as u32; +// let hihalf = (key >> 32) as u32; +// lohalf ^= if key >= 0 { hihalf } else { !hihalf }; +// Self::combine(0, Self::pg_u32_hash(lohalf)) as usize % self.shards +// } + +// /// Example of a hashing function based on SHA1. +// fn sha1(&self, key: i64) -> usize { +// let mut hasher = Sha1::new(); + +// hasher.update(&key.to_string().as_bytes()); + +// let result = hasher.finalize(); + +// // Convert the SHA1 hash into hex so we can parse it as a large integer. +// let hex = format!("{:x}", result); + +// // Parse the last 8 bytes as an integer (8 bytes = bigint). 
+// let key = i64::from_str_radix(&hex[hex.len() - 8..], 16).unwrap() as usize; + +// key % self.shards +// } + +// #[inline] +// fn rot(x: u32, k: u32) -> u32 { +// (x << k) | (x >> (32 - k)) +// } + +// #[inline] +// fn mix(mut a: u32, mut b: u32, mut c: u32) -> (u32, u32, u32) { +// a = a.wrapping_sub(c); +// a ^= Self::rot(c, 4); +// c = c.wrapping_add(b); + +// b = b.wrapping_sub(a); +// b ^= Self::rot(a, 6); +// a = a.wrapping_add(c); + +// c = c.wrapping_sub(b); +// c ^= Self::rot(b, 8); +// b = b.wrapping_add(a); + +// a = a.wrapping_sub(c); +// a ^= Self::rot(c, 16); +// c = c.wrapping_add(b); + +// b = b.wrapping_sub(a); +// b ^= Self::rot(a, 19); +// a = a.wrapping_add(c); + +// c = c.wrapping_sub(b); +// c ^= Self::rot(b, 4); +// b = b.wrapping_add(a); + +// (a, b, c) +// } + +// #[inline] +// fn _final(mut a: u32, mut b: u32, mut c: u32) -> (u32, u32, u32) { +// c ^= b; +// c = c.wrapping_sub(Self::rot(b, 14)); +// a ^= c; +// a = a.wrapping_sub(Self::rot(c, 11)); +// b ^= a; +// b = b.wrapping_sub(Self::rot(a, 25)); +// c ^= b; +// c = c.wrapping_sub(Self::rot(b, 16)); +// a ^= c; +// a = a.wrapping_sub(Self::rot(c, 4)); +// b ^= a; +// b = b.wrapping_sub(Self::rot(a, 14)); +// c ^= b; +// c = c.wrapping_sub(Self::rot(b, 24)); +// (a, b, c) +// } + +// #[inline] +// fn combine(mut a: u64, b: u64) -> u64 { +// a ^= b +// .wrapping_add(0x49a0f4dd15e5a8e3 as u64) +// .wrapping_add(a << 54) +// .wrapping_add(a >> 7); +// a +// } + +// #[inline] +// fn pg_u32_hash(k: u32) -> u64 { +// let mut a: u32 = 0x9e3779b9 as u32 + std::mem::size_of::<u32>() as u32 + 3923095 as u32; +// let mut b = a; +// let c = a; + +// a = a.wrapping_add((PARTITION_HASH_SEED >> 32) as u32); +// b = b.wrapping_add(PARTITION_HASH_SEED as u32); +// let (mut a, b, c) = Self::mix(a, b, c); + +// a = a.wrapping_add(k); + +// let (_a, b, c) = Self::_final(a, b, c); + +// ((b as u64) << 32) | (c as u64) +// } diff --git a/lib/gat/sharding_test.go b/lib/gat/sharding_test.go new file mode 100644 index 0000000000000000000000000000000000000000..23bdd97bcb6617d4845cf2a2606d6b31acb406f9 --- /dev/null +++ b/lib/gat/sharding_test.go @@ -0,0 +1,61 @@ +package gat + +//TODO: convert test + +//#[cfg(test)] +//mod test { +// use super::*; +// +// // See tests/sharding/partition_hash_test_setup.sql +// // The output of those SELECT statements will match this test, +// // confirming that we implemented Postgres BIGINT hashing correctly. 
+// #[test] +// fn test_pg_bigint_hash() { +// let sharder = Sharder::new(5, ShardingFunction::PgBigintHash); +// +// let shard_0 = vec![1, 4, 5, 14, 19, 39, 40, 46, 47, 53]; +// +// for v in shard_0 { +// assert_eq!(sharder.shard(v), 0); +// } +// +// let shard_1 = vec![2, 3, 11, 17, 21, 23, 30, 49, 51, 54]; +// +// for v in shard_1 { +// assert_eq!(sharder.shard(v), 1); +// } +// +// let shard_2 = vec![6, 7, 15, 16, 18, 20, 25, 28, 34, 35]; +// +// for v in shard_2 { +// assert_eq!(sharder.shard(v), 2); +// } +// +// let shard_3 = vec![8, 12, 13, 22, 29, 31, 33, 36, 41, 43]; +// +// for v in shard_3 { +// assert_eq!(sharder.shard(v), 3); +// } +// +// let shard_4 = vec![9, 10, 24, 26, 27, 32, 37, 38, 42, 45]; +// +// for v in shard_4 { +// assert_eq!(sharder.shard(v), 4); +// } +// } +// +// #[test] +// fn test_sha1_hash() { +// let sharder = Sharder::new(12, ShardingFunction::Sha1); +// let ids = vec![ +// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, +// ]; +// let shards = vec![ +// 4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3, +// ]; +// +// for (i, id) in ids.iter().enumerate() { +// assert_eq!(sharder.shard(*id), shards[i]); +// } +// } +//}
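The shard assignments these tests expect come from Postgres' own BIGINT hash partitioning, so the Go port has to reproduce that hash bit for bit. Below is a possible translation of the commented-out Rust in sharding.go; the names (pgBigintHash, pgU32Hash, and so on) are mine and the code has not been run against the fixtures above, so treat it as a starting point rather than a verified port.

func rot(x, k uint32) uint32 {
    return (x << k) | (x >> (32 - k))
}

// mix and final are the Jenkins-style mixing rounds from the commented Rust
// reference; unsigned Go arithmetic wraps, matching the wrapping_* calls there.
func mix(a, b, c uint32) (uint32, uint32, uint32) {
    a -= c
    a ^= rot(c, 4)
    c += b
    b -= a
    b ^= rot(a, 6)
    a += c
    c -= b
    c ^= rot(b, 8)
    b += a
    a -= c
    a ^= rot(c, 16)
    c += b
    b -= a
    b ^= rot(a, 19)
    a += c
    c -= b
    c ^= rot(b, 4)
    b += a
    return a, b, c
}

func final(a, b, c uint32) (uint32, uint32, uint32) {
    c ^= b
    c -= rot(b, 14)
    a ^= c
    a -= rot(c, 11)
    b ^= a
    b -= rot(a, 25)
    c ^= b
    c -= rot(b, 16)
    a ^= c
    a -= rot(c, 4)
    b ^= a
    b -= rot(a, 14)
    c ^= b
    c -= rot(b, 24)
    return a, b, c
}

func combine(a, b uint64) uint64 {
    a ^= b + 0x49a0f4dd15e5a8e3 + (a << 54) + (a >> 7)
    return a
}

func pgU32Hash(k uint32) uint64 {
    seed := uint64(PARTITION_HASH_SEED)
    // as in the Rust reference: golden-ratio constant + size_of::<u32>() + 3923095
    a := uint32(0x9e3779b9) + 4 + 3923095
    b := a
    c := a
    a += uint32(seed >> 32)
    b += uint32(seed)
    a, b, c = mix(a, b, c)
    a += k
    _, b, c = final(a, b, c)
    return uint64(b)<<32 | uint64(c)
}

// pgBigintHash maps a BIGINT sharding key onto one of s.shards shards the
// same way Postgres hash partitioning does.
func (s *Sharder) pgBigintHash(key int64) int {
    lo := uint32(key)
    hi := uint32(key >> 32)
    if key >= 0 {
        lo ^= hi
    } else {
        lo ^= ^hi
    }
    return int(combine(0, pgU32Hash(lo)) % uint64(s.shards))
}

A Sharder could then expose this through its fn field (as a method value), so that sharder.fn(key) selects shards the way the test_pg_bigint_hash fixtures assume.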