Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make client use .well-known redirects #233

Merged
merged 1 commit into from
May 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 185 additions & 44 deletions matrix_sdk/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,33 @@ pub enum LoopCtrl {
Break,
}

#[cfg(feature = "encryption")]
use matrix_sdk_common::{
api::r0::{
account::register,
device::{delete_devices, get_devices},
directory::{get_public_rooms, get_public_rooms_filtered},
filter::{create_filter::Request as FilterUploadRequest, FilterDefinition},
media::{create_content, get_content, get_content_thumbnail},
membership::{join_room_by_id, join_room_by_id_or_alias},
message::send_message_event,
profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name},
room::create_room,
session::{get_login_types, login, sso_login},
sync::sync_events,
uiaa::AuthData,
keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest},
to_device::send_event_to_device::{
Request as RumaToDeviceRequest, Response as ToDeviceResponse,
},
},
identifiers::EventId,
};
use matrix_sdk_common::{
api::{
r0::{
account::register,
device::{delete_devices, get_devices},
directory::{get_public_rooms, get_public_rooms_filtered},
filter::{create_filter::Request as FilterUploadRequest, FilterDefinition},
media::{create_content, get_content, get_content_thumbnail},
membership::{join_room_by_id, join_room_by_id_or_alias},
message::send_message_event,
profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name},
room::create_room,
session::{get_login_types, login, sso_login},
sync::sync_events,
uiaa::AuthData,
},
unversioned::{discover_homeserver, get_supported_versions},
},
assign,
identifiers::{DeviceIdBox, RoomId, RoomIdOrAliasId, ServerName, UserId},
Expand All @@ -98,16 +111,6 @@ use matrix_sdk_common::{
uuid::Uuid,
FromHttpResponseError, UInt,
};
#[cfg(feature = "encryption")]
use matrix_sdk_common::{
api::r0::{
keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest},
to_device::send_event_to_device::{
Request as RumaToDeviceRequest, Response as ToDeviceResponse,
},
},
identifiers::EventId,
};

#[cfg(feature = "encryption")]
use crate::{
Expand Down Expand Up @@ -142,7 +145,7 @@ const SSO_SERVER_BIND_TRIES: u8 = 10;
#[derive(Clone)]
pub struct Client {
/// The URL of the homeserver to connect to.
homeserver: Arc<Url>,
homeserver: Arc<RwLock<Url>>,
/// The underlying HTTP client.
http_client: HttpClient,
/// User session data.
Expand All @@ -164,7 +167,7 @@ pub struct Client {
#[cfg(not(tarpaulin_include))]
impl Debug for Client {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> {
write!(fmt, "Client {{ homeserver: {} }}", self.homeserver)
write!(fmt, "Client")
}
}

Expand Down Expand Up @@ -502,7 +505,7 @@ impl Client {
///
/// * `config` - Configuration for the client.
pub fn new_with_config(homeserver_url: Url, config: ClientConfig) -> Result<Self> {
let homeserver = Arc::new(homeserver_url);
let homeserver = Arc::new(RwLock::new(homeserver_url));

let client = if let Some(client) = config.client {
client
Expand All @@ -513,12 +516,8 @@ impl Client {
let base_client = BaseClient::new_with_config(config.base_config)?;
let session = base_client.session().clone();

let http_client = HttpClient {
homeserver: homeserver.clone(),
inner: client,
session,
request_config: config.request_config,
};
let http_client =
HttpClient::new(client, homeserver.clone(), session, config.request_config);

Ok(Self {
homeserver,
Expand All @@ -534,6 +533,89 @@ impl Client {
})
}

/// Creates a new client for making HTTP requests to the homeserver of the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably mention how it differs from the new() constructor. Explain which steps it takes to discover the homeserver and link to the relevant spec entry, an example wouldn't hurt either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

/// given user. Follows homeserver discovery directions described
/// [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri).
///
/// # Arguments
///
/// * `user_id` - The id of the user whose homeserver the client should
/// connect to.
///
/// # Example
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, identifiers::UserId};
/// # use futures::executor::block_on;
/// let alice = UserId::try_from("@alice:example.org").unwrap();
/// # block_on(async {
/// let client = Client::new_from_user_id(alice.clone()).await.unwrap();
/// client.login(alice.localpart(), "password", None, None).await.unwrap();
/// # });
/// ```
pub async fn new_from_user_id(user_id: UserId) -> Result<Self> {
let config = ClientConfig::new();
Client::new_from_user_id_with_config(user_id, config).await
}

/// Creates a new client for making HTTP requests to the homeserver of the
/// given user and configuration. Follows homeserver discovery directions
/// described [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri).
///
/// # Arguments
///
/// * `user_id` - The id of the user whose homeserver the client should
/// connect to.
///
/// * `config` - Configuration for the client.
pub async fn new_from_user_id_with_config(
user_id: UserId,
config: ClientConfig,
) -> Result<Self> {
let homeserver = Client::homeserver_from_user_id(user_id)?;
let mut client = Client::new_with_config(homeserver, config)?;

let well_known = client.discover_homeserver().await?;
let well_known = Url::parse(well_known.homeserver.base_url.as_ref())?;
client.set_homeserver(well_known).await;
client.get_supported_versions().await?;
Ok(client)
}

fn homeserver_from_user_id(user_id: UserId) -> Result<Url> {
let homeserver = format!("https://{}", user_id.server_name());
#[allow(unused_mut)]
let mut result = Url::parse(homeserver.as_str())?;
// Mockito only knows how to test http endpoints:
// https://github.com/lipanski/mockito/issues/127
#[cfg(test)]
let _ = result.set_scheme("http");
Ok(result)
}

async fn discover_homeserver(&self) -> Result<discover_homeserver::Response> {
self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry()))
.await
}

/// Change the homeserver URL used by this client.
///
/// # Arguments
///
/// * `homeserver_url` - The new URL to use.
pub async fn set_homeserver(&mut self, homeserver_url: Url) {
let mut homeserver = self.homeserver.write().await;
*homeserver = homeserver_url;
}

async fn get_supported_versions(&self) -> Result<get_supported_versions::Response> {
self.send(
get_supported_versions::Request::new(),
Some(RequestConfig::new().disable_retry()),
)
.await
}

/// Process a [transaction] received from the homeserver
///
/// # Arguments
Expand Down Expand Up @@ -566,8 +648,8 @@ impl Client {
}

/// The Homeserver of the client.
pub fn homeserver(&self) -> &Url {
&self.homeserver
pub async fn homeserver(&self) -> Url {
self.homeserver.read().await.clone()
}

/// Get the user id of the current owner of the client.
Expand Down Expand Up @@ -866,8 +948,8 @@ impl Client {
/// successful SSO login.
///
/// [`login_with_token`]: #method.login_with_token
pub fn get_sso_login_url(&self, redirect_url: &str) -> Result<String> {
let homeserver = self.homeserver();
pub async fn get_sso_login_url(&self, redirect_url: &str) -> Result<String> {
let homeserver = self.homeserver().await;
let request = sso_login::Request::new(redirect_url)
.try_into_http_request::<Vec<u8>>(homeserver.as_str(), SendAccessToken::None);
match request {
Expand Down Expand Up @@ -928,7 +1010,7 @@ impl Client {
device_id: Option<&str>,
initial_device_display_name: Option<&str>,
) -> Result<login::Response> {
info!("Logging in to {} as {:?}", self.homeserver, user);
info!("Logging in to {} as {:?}", self.homeserver().await, user);

let request = assign!(
login::Request::new(
Expand Down Expand Up @@ -1037,7 +1119,7 @@ impl Client {
where
C: Future<Output = Result<()>>,
{
info!("Logging in to {}", self.homeserver);
info!("Logging in to {}", self.homeserver().await);
let (signal_tx, signal_rx) = oneshot::channel();
let (data_tx, data_rx) = oneshot::channel();
let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx)));
Expand Down Expand Up @@ -1109,7 +1191,7 @@ impl Client {

tokio::spawn(server);

let sso_url = self.get_sso_login_url(redirect_url.as_str()).unwrap();
let sso_url = self.get_sso_login_url(redirect_url.as_str()).await.unwrap();

match use_sso_login_url(sso_url).await {
Ok(t) => t,
Expand Down Expand Up @@ -1193,7 +1275,7 @@ impl Client {
device_id: Option<&str>,
initial_device_display_name: Option<&str>,
) -> Result<login::Response> {
info!("Logging in to {}", self.homeserver);
info!("Logging in to {}", self.homeserver().await);

let request = assign!(
login::Request::new(
Expand Down Expand Up @@ -1264,7 +1346,7 @@ impl Client {
&self,
registration: impl Into<register::Request<'_>>,
) -> Result<register::Response> {
info!("Registering to {}", self.homeserver);
info!("Registering to {}", self.homeserver().await);

let request = registration.into();
self.send(request, None).await
Expand Down Expand Up @@ -2387,7 +2469,13 @@ impl Client {

#[cfg(test)]
mod test {
use std::{collections::BTreeMap, convert::TryInto, io::Cursor, str::FromStr, time::Duration};
use std::{
collections::BTreeMap,
convert::{TryFrom, TryInto},
io::Cursor,
str::FromStr,
time::Duration,
};

use matrix_sdk_base::identifiers::mxc_uri;
use matrix_sdk_common::{
Expand All @@ -2399,7 +2487,7 @@ mod test {
assign,
directory::Filter,
events::{room::message::MessageEventContent, AnyMessageEventContent},
identifiers::{event_id, room_id, user_id},
identifiers::{event_id, room_id, user_id, UserId},
thirdparty,
};
use matrix_sdk_test::{test_json, EventBuilder, EventsJson};
Expand All @@ -2425,6 +2513,59 @@ mod test {
client
}

#[tokio::test]
async fn set_homeserver() {
let homeserver = Url::from_str("http://example.com/").unwrap();

let mut client = Client::new(homeserver).unwrap();

let homeserver = Url::from_str(&mockito::server_url()).unwrap();

client.set_homeserver(homeserver.clone()).await;

assert_eq!(client.homeserver().await, homeserver);
}

#[tokio::test]
async fn successful_discovery() {
let server_url = mockito::server_url();
let domain = server_url.strip_prefix("http://").unwrap();
let alice = UserId::try_from("@alice:".to_string() + domain).unwrap();

let _m_well_known = mock("GET", "/.well-known/matrix/client")
.with_status(200)
.with_body(
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
)
.create();

let _m_versions = mock("GET", "/_matrix/client/versions")
.with_status(200)
.with_body(test_json::VERSIONS.to_string())
.create();
let client = Client::new_from_user_id(alice).await.unwrap();

assert_eq!(client.homeserver().await, Url::parse(server_url.as_ref()).unwrap());
}

#[tokio::test]
async fn discovery_broken_server() {
let server_url = mockito::server_url();
let domain = server_url.strip_prefix("http://").unwrap();
let alice = UserId::try_from("@alice:".to_string() + domain).unwrap();

let _m = mock("GET", "/.well-known/matrix/client")
.with_status(200)
.with_body(
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
)
.create();

if Client::new_from_user_id(alice).await.is_ok() {
panic!("Creating a client from a user ID should fail when the .well-known server returns no version infromation.");
}
}

#[tokio::test]
async fn login() {
let homeserver = Url::from_str(&mockito::server_url()).unwrap();
Expand Down Expand Up @@ -2514,7 +2655,7 @@ mod test {
.any(|flow| matches!(flow, LoginType::Sso(_)));
assert!(can_sso);

let sso_url = client.get_sso_login_url("http://127.0.0.1:3030");
let sso_url = client.get_sso_login_url("http://127.0.0.1:3030").await;
assert!(sso_url.is_ok());

let _m = mock("POST", "/_matrix/client/r0/login")
Expand Down Expand Up @@ -2626,7 +2767,7 @@ mod test {
client.base_client.receive_sync_response(response).await.unwrap();
let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");

assert_eq!(client.homeserver(), &Url::parse(&mockito::server_url()).unwrap());
assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap());

let room = client.get_joined_room(&room_id);
assert!(room.is_some());
Expand Down
5 changes: 5 additions & 0 deletions matrix_sdk/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use matrix_sdk_common::{
use reqwest::Error as ReqwestError;
use serde_json::Error as JsonError;
use thiserror::Error;
use url::ParseError as UrlParseError;

/// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -128,6 +129,10 @@ pub enum Error {
/// An error encountered when trying to parse an identifier.
#[error(transparent)]
Identifier(#[from] IdentifierError),

/// An error encountered when trying to parse a url.
#[error(transparent)]
Url(#[from] UrlParseError),
}

impl Error {
Expand Down
Loading