use anyhow::Result;
use futures::{future::RemoteHandle, FutureExt};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::Display;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use tokio::net::{lookup_host, ToSocketAddrs};
use crate::cql_to_rust::{FromRow, FromRowError};
use crate::frame::response::result;
use crate::frame::response::Response;
use crate::frame::value::Value;
use crate::prepared_statement::PreparedStatement;
use crate::query::Query;
use crate::routing::{murmur3_token, Node, Shard, ShardInfo, Token};
use crate::transport::connection::{open_connection, Connection};
use crate::transport::iterator::RowIterator;
use crate::transport::topology::{Topology, TopologyReader};
use crate::transport::Compression;
/// Error message reported whenever the connection-pool `RwLock` turns out to be poisoned
/// (a thread panicked while holding it); the session cannot recover from that state.
const POOL_LOCK_POISONED: &str = concat!(
    "Connection pool lock is poisoned, this session is no longer usable. ",
    "Drop it and create a new one."
);
/// A handle to a cluster: a lazily-grown pool of connections (one pool per node) plus the
/// cluster topology used for token-aware routing.
pub struct Session {
    /// Per-node connection pools behind a shared lock. Pools for further nodes are added on
    /// demand when routing picks a node that has no pool yet (see `pick_connection`).
    pool: Arc<RwLock<HashMap<Node, NodePool>>>,
    /// Compression setting passed to every connection this session opens.
    compression: Option<Compression>,
    /// Cluster topology; its ring maps tokens to their owning nodes.
    topology: Topology,
    /// Handle to the spawned topology-reader task. It is never read; it is held so that the task
    /// lives as long as the session (dropping a `RemoteHandle` drops its remote future).
    /// Fields drop in declaration order, so this handle is dropped last.
    _topology_reader_handle: RemoteHandle<()>,
}
/// The connections to a single node, keyed by shard.
struct NodePool {
    /// At most one connection per shard. A pool without shard info keeps its single
    /// connection under shard 0.
    connections: HashMap<Shard, Arc<Connection>>,
    /// The node's sharding parameters. Present only when the seed connection was shard-aware;
    /// used to map tokens to shards.
    shard_info: Option<ShardInfo>,
}
/// Conversion of untyped result rows into an iterator of typed rows.
pub trait IntoTypedRows {
    /// Consumes `self` and returns an iterator that converts every row into `RowT` using
    /// [`FromRow`], yielding one `Result` per row.
    fn into_typed<RowT>(self) -> TypedRowIter<RowT>
    where
        RowT: FromRow;
}
/// Lets a page of rows returned by `query`/`execute` be iterated as typed values.
impl IntoTypedRows for Vec<result::Row> {
    fn into_typed<RowT: FromRow>(self) -> TypedRowIter<RowT> {
        // Rows are consumed lazily; conversion happens one row at a time in `next`.
        let row_iter = self.into_iter();
        TypedRowIter {
            row_iter,
            phantom_data: std::marker::PhantomData,
        }
    }
}
pub struct TypedRowIter<RowT: FromRow> {
row_iter: std::vec::IntoIter<result::Row>,
phantom_data: std::marker::PhantomData<RowT>,
}
impl<RowT: FromRow> Iterator for TypedRowIter<RowT> {
    /// Each row either converts successfully or yields the conversion error for that row;
    /// a failed row does not stop the iteration.
    type Item = Result<RowT, FromRowError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.row_iter.next().map(RowT::from_row)
    }

    /// Every underlying row produces exactly one item, so the (exact) hint of the
    /// underlying `vec::IntoIter` is forwarded unchanged; this lets `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.row_iter.size_hint()
    }
}
impl Session {
pub async fn connect(
addr: impl ToSocketAddrs + Display,
compression: Option<Compression>,
) -> Result<Self> {
let addr = resolve(addr).await?;
let connection = open_connection(addr, None, compression).await?;
let node = Node { addr };
let (topology_reader, topology) = TopologyReader::new(node).await?;
let (fut, _topology_reader_handle) = topology_reader.run().remote_handle();
tokio::task::spawn(fut);
let pool = Arc::new(RwLock::new(
vec![(node, NodePool::new(Arc::new(connection)))]
.into_iter()
.collect(),
));
Ok(Session {
pool,
compression,
topology,
_topology_reader_handle,
})
}
pub async fn query(
&self,
query: impl Into<Query>,
values: &[Value],
) -> Result<Option<Vec<result::Row>>> {
self.any_connection()?
.query_single_page(query, values)
.await
}
pub fn query_iter(&self, query: impl Into<Query>, values: &[Value]) -> Result<RowIterator> {
Ok(RowIterator::new_for_query(
self.any_connection()?,
query.into(),
values.to_owned(),
))
}
pub async fn prepare(&self, query: &str) -> Result<PreparedStatement> {
let result = self.any_connection()?.prepare(query.to_owned()).await?;
match result {
Response::Error(err) => Err(err.into()),
Response::Result(result::Result::Prepared(p)) => Ok(PreparedStatement::new(
p.id,
p.prepared_metadata,
query.to_owned(),
)),
_ => Err(anyhow!("Unexpected frame received")),
}
}
pub async fn execute(
&self,
prepared: &PreparedStatement,
values: &[Value],
) -> Result<Option<Vec<result::Row>>> {
let token = calculate_token(prepared, values);
let connection = self.pick_connection(token).await?;
let result = connection.execute(prepared, values, None).await?;
match result {
Response::Error(err) => {
match err.code {
9472 => {
let reprepared_result = connection
.prepare(prepared.get_statement().to_owned())
.await?;
match reprepared_result {
Response::Error(err) => return Err(err.into()),
Response::Result(result::Result::Prepared(reprepared)) => {
if reprepared.id != prepared.get_id() {
return Err(anyhow!(
"Reprepared statement unexpectedly changed its id"
));
}
}
_ => return Err(anyhow!("Unexpected frame received")),
}
let result = connection.execute(prepared, values, None).await?;
match result {
Response::Error(err) => Err(err.into()),
Response::Result(result::Result::Rows(rs)) => Ok(Some(rs.rows)),
Response::Result(_) => Ok(None),
_ => Err(anyhow!("Unexpected frame received")),
}
}
_ => Err(err.into()),
}
}
Response::Result(result::Result::Rows(rs)) => Ok(Some(rs.rows)),
Response::Result(_) => Ok(None),
_ => Err(anyhow!("Unexpected frame received")),
}
}
pub fn execute_iter(
&self,
prepared: impl Into<PreparedStatement>,
values: &[Value],
) -> Result<RowIterator> {
Ok(RowIterator::new_for_prepared_statement(
self.any_connection()?,
prepared.into(),
values.to_owned(),
))
}
pub fn get_connections(&self) -> Result<Vec<(Node, Vec<Arc<Connection>>)>> {
Ok(self
.pool
.read()
.map_err(|_| anyhow!(POOL_LOCK_POISONED))?
.iter()
.map(|(&node, node_pool)| (node, node_pool.get_connections()))
.collect())
}
pub async fn refresh_topology(&self) -> Result<()> {
self.topology.refresh().await
}
async fn pick_connection(&self, t: Token) -> Result<Arc<Connection>> {
let owner = self.topology.read_ring()?.owner(t);
let mut found_pool = false;
let mut shard_info = None;
if let Some(node_pool) = self
.pool
.read()
.map_err(|_| anyhow!(POOL_LOCK_POISONED))?
.get(&owner)
{
if let Some(c) = node_pool.connection_for_token(t) {
return Ok(c);
}
found_pool = true;
shard_info = node_pool.get_shard_info().cloned();
};
if !found_pool {
let new_conn = open_connection(owner.addr, None, self.compression).await?;
let new_node_pool = NodePool::new(Arc::new(new_conn));
return Ok(self
.pool
.write()
.map_err(|_| anyhow!(POOL_LOCK_POISONED))?
.entry(owner)
.or_insert(new_node_pool)
.connection_for_token_or_any(t)?);
};
let new_conn_source_port = shard_info
.as_ref()
.map(|info| info.draw_source_port_for_token(t));
let new_conn =
Arc::new(open_connection(owner.addr, new_conn_source_port, self.compression).await?);
Ok(
match self
.pool
.write()
.map_err(|_| anyhow!(POOL_LOCK_POISONED))?
.entry(owner)
{
Entry::Vacant(entry) => {
entry.insert(NodePool::new(new_conn.clone()));
new_conn
}
Entry::Occupied(mut entry) => {
let node_pool = entry.get_mut();
if shard_info.as_ref() != node_pool.get_shard_info() {
return node_pool.connection_for_token_or_any(t);
}
let shard = node_pool.shard_for_token(t);
node_pool.or_insert(shard, new_conn)
}
},
)
}
fn any_connection(&self) -> Result<Arc<Connection>> {
Ok(self
.pool
.read()
.map_err(|_| anyhow!(POOL_LOCK_POISONED))?
.values()
.next()
.ok_or_else(|| anyhow!("fatal error, broken invariant: no connections available"))?
.any_connection()?)
}
}
impl NodePool {
pub fn new(connection: Arc<Connection>) -> Self {
let (shard, shard_info) =
match (connection.get_shard_info(), connection.get_is_shard_aware()) {
(Some(info), true) => (
info.shard_of_source_port(connection.get_source_port()),
Some(info.to_owned()),
),
_ => (0, None),
};
Self {
connections: vec![(shard, connection)].into_iter().collect(),
shard_info,
}
}
pub fn get_connections(&self) -> Vec<Arc<Connection>> {
self.connections
.values()
.map(|conn| conn.to_owned())
.collect()
}
pub fn get_shard_info(&self) -> Option<&ShardInfo> {
self.shard_info.as_ref()
}
pub fn shard_for_token(&self, t: Token) -> Shard {
self.shard_info.as_ref().map_or(0, |info| info.shard_of(t))
}
pub fn connection_for_token(&self, t: Token) -> Option<Arc<Connection>> {
let shard = self.shard_for_token(t);
self.connections.get(&shard).map(|c| c.to_owned())
}
pub fn connection_for_token_or_any(&self, t: Token) -> Result<Arc<Connection>> {
self.connection_for_token(t)
.map_or_else(|| self.any_connection(), Ok)
}
pub fn any_connection(&self) -> Result<Arc<Connection>> {
Ok(self
.connections
.values()
.next()
.ok_or_else(|| anyhow!("fatal error, broken invariant: no connections available"))?
.to_owned())
}
pub fn or_insert(&mut self, shard: Shard, conn: Arc<Connection>) -> Arc<Connection> {
self.connections.entry(shard).or_insert(conn).to_owned()
}
}
/// Computes the ring token for executing `stmt` with `values`.
///
/// The partition key is extracted from `values` by the statement and hashed with murmur3.
/// The lifetimes are left to elision: the explicit `'a` tied nothing together.
fn calculate_token(stmt: &PreparedStatement, values: &[Value]) -> Token {
    murmur3_token(stmt.compute_partition_key(values))
}
/// Resolves `addr` to a single socket address.
///
/// The first IPv4 address returned by the lookup wins. If there is none, the last IPv6
/// address seen is returned.
///
/// # Errors
/// Propagates lookup errors, and fails with "failed to resolve ..." when the lookup yields
/// no addresses.
async fn resolve(addr: impl ToSocketAddrs + Display) -> Result<SocketAddr> {
    // Built before `addr` is moved into `lookup_host`.
    let failed_err = anyhow!("failed to resolve {}", addr);
    let mut fallback = None;
    for candidate in lookup_host(addr).await? {
        if candidate.is_ipv4() {
            return Ok(candidate);
        }
        fallback = Some(candidate);
    }
    fallback.ok_or(failed_err)
}