This commit is contained in:
2022-12-30 20:55:14 +08:00
parent 5a9c09d673
commit 118d6a5a1d
53 changed files with 4720 additions and 1 deletions

156
src/auth.rs Normal file
View File

@@ -0,0 +1,156 @@
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use crate::config::{AuthType, Config, Location};
use headers::{authorization::Basic, Authorization, HeaderMapExt};
use http::status::StatusCode;
type HttpRequest = http::Request<hyper::Body>;
#[derive(Clone)]
pub struct Auth {
config: Arc<Config>,
#[cfg(feature = "pam")]
pam_auth: pam_sandboxed::PamAuth,
}
impl Auth {
pub fn new(config: Arc<Config>) -> io::Result<Auth> {
// initialize pam.
#[cfg(feature = "pam")]
let pam_auth = {
// set cache timeouts.
if let Some(timeout) = config.pam.cache_timeout {
crate::cache::cached::set_pamcache_timeout(timeout);
}
pam_sandboxed::PamAuth::new(config.pam.threads.clone())?
};
Ok(Auth {
#[cfg(feature = "pam")]
pam_auth,
config,
})
}
// authenticate user.
pub async fn auth<'a>(
&'a self,
req: &'a HttpRequest,
location: &Location,
_remote_ip: SocketAddr,
) -> Result<String, StatusCode>
{
// we must have a login/pass
let basic = match req.headers().typed_get::<Authorization<Basic>>() {
Some(Authorization(basic)) => basic,
_ => return Err(StatusCode::UNAUTHORIZED),
};
let user = basic.username();
let pass = basic.password();
// match the auth type.
let auth_type = location
.accounts
.auth_type
.as_ref()
.or(self.config.accounts.auth_type.as_ref());
match auth_type {
#[cfg(feature = "pam")]
Some(&AuthType::Pam) => self.auth_pam(req, user, pass, _remote_ip).await,
Some(&AuthType::HtPasswd(ref ht)) => self.auth_htpasswd(user, pass, ht.as_str()).await,
None => {
debug!("need authentication, but auth-type is not set");
Err(StatusCode::UNAUTHORIZED)
},
}
}
// authenticate user using PAM.
#[cfg(feature = "pam")]
async fn auth_pam<'a>(
&'a self,
req: &'a HttpRequest,
user: &'a str,
pass: &'a str,
remote_ip: SocketAddr,
) -> Result<String, StatusCode>
{
// stringify the remote IP address.
let ip = remote_ip.ip();
let ip_string = if ip.is_loopback() {
// if it's loopback, take the value from the x-forwarded-for
// header, if present.
req.headers()
.get("x-forwarded-for")
.and_then(|s| s.to_str().ok())
.and_then(|s| s.split(',').next())
.map(|s| s.trim().to_owned())
} else {
Some(match ip {
std::net::IpAddr::V4(ip) => ip.to_string(),
std::net::IpAddr::V6(ip) => ip.to_string(),
})
};
let ip_ref = ip_string.as_ref().map(|s| s.as_str());
// authenticate.
let service = self.config.pam.service.as_str();
let pam_auth = self.pam_auth.clone();
match crate::cache::cached::pam_auth(pam_auth, service, user, pass, ip_ref).await {
Ok(_) => Ok(user.to_string()),
Err(_) => {
debug!(
"auth_pam({}): authentication for {} ({:?}) failed",
service, user, ip_ref
);
Err(StatusCode::UNAUTHORIZED)
},
}
}
// authenticate user using htpasswd.
async fn auth_htpasswd<'a>(
&'a self,
user: &'a str,
pass: &'a str,
section: &'a str,
) -> Result<String, StatusCode>
{
// Get the htpasswd.WHATEVER section from the config file.
let file = match self.config.htpasswd.get(section) {
Some(section) => section.htpasswd.as_str(),
None => return Err(StatusCode::UNAUTHORIZED),
};
// Read the file and split it into a bunch of lines.
tokio::task::block_in_place(move || {
let data = match std::fs::read_to_string(file) {
Ok(data) => data,
Err(e) => {
debug!("{}: {}", file, e);
return Err(StatusCode::UNAUTHORIZED);
},
};
let lines = data
.split('\n')
.map(|s| s.trim())
.filter(|s| !s.starts_with("#") && !s.is_empty());
// Check each line for a match.
for line in lines {
let mut fields = line.split(':');
if let (Some(htuser), Some(htpass)) = (fields.next(), fields.next()) {
if htuser == user && pwhash::unix::verify(pass, htpass) {
return Ok(user.to_string());
}
}
}
debug!("auth_htpasswd: authentication for {} failed", user);
Err(StatusCode::UNAUTHORIZED)
})
}
}

185
src/cache.rs Normal file
View File

@@ -0,0 +1,185 @@
use std::borrow::Borrow;
use std::cmp::Eq;
use std::collections::vec_deque::VecDeque;
use std::collections::HashMap;
use std::hash::Hash;
use std::option::Option;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
#[allow(dead_code)]
pub struct Cache<K, V> {
intern: Mutex<Intern<K, V>>,
}
struct Intern<K, V> {
maxsize: usize,
maxage: Duration,
map: HashMap<K, Arc<V>>,
fifo: VecDeque<(Instant, K)>,
}
impl<K: Hash + Eq + Clone, V> Cache<K, V> {
pub fn new() -> Cache<K, V> {
let i = Intern {
maxsize: 0,
maxage: Duration::new(0, 0),
map: HashMap::new(),
fifo: VecDeque::new(),
};
Cache {
intern: Mutex::new(i),
}
}
#[allow(dead_code)]
pub fn maxsize(self, maxsize: usize) -> Self {
self.intern.lock().unwrap().maxsize = maxsize;
self
}
#[allow(dead_code)]
pub fn maxage(self, maxage: Duration) -> Self {
self.intern.lock().unwrap().maxage = maxage;
self
}
fn expire(&self, m: &mut Intern<K, V>) {
let mut n = m.fifo.len();
if m.maxsize > 0 && n >= m.maxsize {
n = m.maxsize;
}
if m.maxage.as_secs() > 0 || m.maxage.subsec_nanos() > 0 {
let now = Instant::now();
while n > 0 {
let &(t, _) = m.fifo.get(n - 1).unwrap();
if now.duration_since(t) <= m.maxage {
break;
}
n -= 1;
}
}
for x in n..m.fifo.len() {
let &(_, ref key) = m.fifo.get(x).unwrap();
m.map.remove(&key);
}
m.fifo.truncate(n);
}
pub fn insert(&self, key: K, val: V) -> Arc<V> {
let mut m = self.intern.lock().unwrap();
self.expire(&mut *m);
let av = Arc::new(val);
let ac = av.clone();
m.map.insert(key.clone(), av);
m.fifo.push_front((Instant::now(), key));
ac
}
// see https://doc.rust-lang.org/book/first-edition/borrow-and-asref.html
pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<Arc<V>>
where
K: Borrow<Q>,
Q: Hash + Eq,
{
let mut m = self.intern.lock().unwrap();
self.expire(&mut *m);
if let Some(v) = m.map.get(key) {
return Some(v.clone());
}
None
}
}
pub(crate) mod cached {
//
// Cached versions of Unix account lookup and Pam auth.
//
use std::io;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::cache;
use crate::unixuser::{self, User};
use lazy_static::lazy_static;
struct Timeouts {
pwcache: Duration,
pamcache: Duration,
}
lazy_static! {
static ref TIMEOUTS: Mutex<Timeouts> = Mutex::new(Timeouts {
pwcache: Duration::new(120, 0),
pamcache: Duration::new(120, 0),
});
static ref PWCACHE: cache::Cache<String, unixuser::User> = new_pwcache();
static ref PAMCACHE: cache::Cache<u64, String> = new_pamcache();
}
fn new_pwcache() -> cache::Cache<String, unixuser::User> {
let timeouts = TIMEOUTS.lock().unwrap();
cache::Cache::new().maxage(timeouts.pwcache)
}
fn new_pamcache() -> cache::Cache<u64, String> {
let timeouts = TIMEOUTS.lock().unwrap();
cache::Cache::new().maxage(timeouts.pamcache)
}
pub(crate) fn set_pwcache_timeout(secs: usize) {
let mut timeouts = TIMEOUTS.lock().unwrap();
timeouts.pwcache = Duration::new(secs as u64, 0);
}
#[cfg(feature = "pam")]
pub(crate) fn set_pamcache_timeout(secs: usize) {
let mut timeouts = TIMEOUTS.lock().unwrap();
timeouts.pamcache = Duration::new(secs as u64, 0);
}
#[cfg(feature = "pam")]
pub async fn pam_auth<'a>(
pam_auth: pam_sandboxed::PamAuth,
service: &'a str,
user: &'a str,
pass: &'a str,
remip: Option<&'a str>,
) -> Result<(), pam_sandboxed::PamError>
{
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut s = DefaultHasher::new();
service.hash(&mut s);
user.hash(&mut s);
pass.hash(&mut s);
remip.as_ref().hash(&mut s);
let key = s.finish();
if let Some(cache_user) = PAMCACHE.get(&key) {
if user == cache_user.as_str() {
return Ok(());
}
}
let mut pam_auth = pam_auth;
match pam_auth.auth(&service, &user, &pass, remip).await {
Err(e) => Err(e),
Ok(()) => {
PAMCACHE.insert(key, user.to_owned());
Ok(())
},
}
}
pub async fn unixuser(username: &str, with_groups: bool) -> Result<Arc<User>, io::Error> {
if let Some(pwd) = PWCACHE.get(username) {
return Ok(pwd);
}
match User::by_name_async(username, with_groups).await {
Err(e) => Err(e),
Ok(pwd) => Ok(PWCACHE.insert(username.to_owned(), pwd)),
}
}
}

339
src/config.rs Normal file
View File

@@ -0,0 +1,339 @@
use std::collections::HashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::Path;
use std::process::exit;
use std::{fs, io};
use enum_from_str::ParseEnumVariantError;
use enum_from_str_derive::FromStr;
use serde::{Deserialize, Deserializer};
use toml;
use webdav_handler::DavMethodSet;
use crate::router::Router;
#[derive(Deserialize, Debug)]
pub struct Config {
pub server: Server,
#[serde(default)]
pub accounts: Accounts,
#[serde(default)]
pub pam: Pam,
#[serde(default)]
pub htpasswd: HashMap<String, HtPasswd>,
#[serde(default)]
pub unix: Unix,
#[serde(default)]
pub location: Vec<Location>,
#[serde(skip)]
pub router: Router<usize>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct Server {
#[serde(default)]
pub listen: OneOrManyAddr,
#[serde(default)]
pub tls_listen: OneOrManyAddr,
#[serde(default)]
pub tls_key: Option<String>,
#[serde(default)]
pub tls_cert: Option<String>,
//#[serde(deserialize_with = "deserialize_user", default)]
pub uid: Option<u32>,
//#[serde(deserialize_with = "deserialize_group", default)]
pub gid: Option<u32>,
#[serde(default)]
pub identification: Option<String>,
#[serde(default)]
pub cors: bool,
}
#[derive(Deserialize, Debug, Clone, Default)]
pub struct Accounts {
#[serde(rename = "auth-type", deserialize_with = "deserialize_authtype", default)]
pub auth_type: Option<AuthType>,
#[serde(rename = "acct-type", deserialize_with = "deserialize_opt_enum", default)]
pub acct_type: Option<AcctType>,
#[serde(default)]
pub realm: Option<String>,
}
#[derive(Deserialize, Debug, Clone, Default)]
pub struct Pam {
pub service: String,
#[serde(rename = "cache-timeout")]
pub cache_timeout: Option<usize>,
pub threads: Option<usize>,
}
#[derive(Deserialize, Debug, Clone, Default)]
pub struct HtPasswd {
pub htpasswd: String,
}
#[derive(Deserialize, Debug, Clone, Default)]
pub struct Unix {
#[serde(rename = "cache-timeout")]
pub cache_timeout: Option<usize>,
#[serde(rename = "min-uid", default)]
pub min_uid: Option<u32>,
#[serde(rename = "supplementary-groups", default)]
pub aux_groups: bool,
}
#[derive(Deserialize, Debug, Clone)]
pub struct Location {
#[serde(default)]
pub route: Vec<String>,
#[serde(deserialize_with = "deserialize_methodset", default)]
pub methods: Option<DavMethodSet>,
#[serde(deserialize_with = "deserialize_opt_enum", default)]
pub auth: Option<Auth>,
#[serde(default, flatten)]
pub accounts: Accounts,
#[serde(deserialize_with = "deserialize_enum")]
pub handler: Handler,
#[serde(default)]
pub setuid: bool,
pub directory: String,
#[serde(default, alias = "hide-symlinks")]
pub hide_symlinks: Option<bool>,
#[serde(default)]
pub indexfile: Option<String>,
#[serde(default)]
pub autoindex: bool,
#[serde(
rename = "case-insensitive",
deserialize_with = "deserialize_opt_enum",
default
)]
pub case_insensitive: Option<CaseInsensitive>,
#[serde(deserialize_with = "deserialize_opt_enum", default)]
pub on_notfound: Option<OnNotfound>,
}
#[derive(FromStr, Debug, Clone, Copy)]
pub enum Handler {
#[from_str = "virtroot"]
Virtroot,
#[from_str = "filesystem"]
Filesystem,
}
#[derive(FromStr, Debug, Clone, Copy)]
pub enum Auth {
#[from_str = "false"]
False,
#[from_str = "true"]
True,
#[from_str = "opportunistic"]
Opportunistic,
#[from_str = "write"]
Write,
}
#[derive(Debug, Clone)]
pub enum AuthType {
#[cfg(feature = "pam")]
Pam,
HtPasswd(String),
}
#[derive(FromStr, Debug, Clone, Copy)]
pub enum AcctType {
#[from_str = "unix"]
Unix,
}
#[derive(FromStr, Debug, Clone, Copy)]
pub enum CaseInsensitive {
#[from_str = "true"]
True,
#[from_str = "ms"]
Ms,
#[from_str = "false"]
False,
}
#[derive(FromStr, Debug, Clone, Copy)]
pub enum OnNotfound {
#[from_str = "continue"]
Continue,
#[from_str = "return"]
Return,
}
#[derive(Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum OneOrManyAddr {
One(SocketAddr),
Many(Vec<SocketAddr>),
}
impl OneOrManyAddr {
pub fn is_empty(&self) -> bool {
match self {
OneOrManyAddr::One(_) => false,
OneOrManyAddr::Many(v) => v.is_empty(),
}
}
}
impl Default for OneOrManyAddr {
fn default() -> Self {
OneOrManyAddr::Many(Vec::new())
}
}
impl ToSocketAddrs for OneOrManyAddr {
type Iter = std::vec::IntoIter<SocketAddr>;
fn to_socket_addrs(&self) -> io::Result<std::vec::IntoIter<SocketAddr>> {
let i = match self {
OneOrManyAddr::Many(ref v) => v.to_owned(),
OneOrManyAddr::One(ref s) => vec![*s],
};
Ok(i.into_iter())
}
}
// keep this here for now, we might implement a enum{(u32, String} later for
// usernames and groupnames.
#[allow(unused)]
pub fn deserialize_user<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
where D: Deserializer<'de> {
let s = String::deserialize(deserializer)?;
s.parse::<u32>()
.map(|v| Some(v))
.map_err(serde::de::Error::custom)
}
#[allow(unused)]
pub fn deserialize_group<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
where D: Deserializer<'de> {
let s = String::deserialize(deserializer)?;
s.parse::<u32>()
.map(|v| Some(v))
.map_err(serde::de::Error::custom)
}
pub fn deserialize_methodset<'de, D>(deserializer: D) -> Result<Option<DavMethodSet>, D::Error>
where D: Deserializer<'de> {
let m = Vec::<String>::deserialize(deserializer)?;
DavMethodSet::from_vec(m)
.map(|v| Some(v))
.map_err(serde::de::Error::custom)
}
pub fn deserialize_authtype<'de, D>(deserializer: D) -> Result<Option<AuthType>, D::Error>
where D: Deserializer<'de> {
let s = String::deserialize(deserializer)?;
if s.starts_with("htpasswd.") {
return Ok(Some(AuthType::HtPasswd(s[9..].to_string())));
}
#[cfg(feature = "pam")]
if &s == "pam" {
return Ok(Some(AuthType::Pam));
}
if s == "" {
return Ok(None);
}
Err(serde::de::Error::custom("unknown auth-type"))
}
pub fn deserialize_opt_enum<'de, D, E>(deserializer: D) -> Result<Option<E>, D::Error>
where
D: Deserializer<'de>,
E: std::str::FromStr,
E::Err: std::fmt::Display,
{
String::deserialize(deserializer)?
.as_str()
.parse::<E>()
.map(|e| Some(e))
.map_err(serde::de::Error::custom)
}
pub fn deserialize_enum<'de, D, E>(deserializer: D) -> Result<E, D::Error>
where
D: Deserializer<'de>,
E: std::str::FromStr,
E::Err: std::fmt::Display,
{
String::deserialize(deserializer)?
.as_str()
.parse::<E>()
.map_err(serde::de::Error::custom)
}
// Read the TOML config into a config::Config struct.
pub fn read(toml_file: impl AsRef<Path>) -> io::Result<Config> {
let buffer = fs::read_to_string(&toml_file)?;
// initial parse.
let config: Config = match toml::from_str(&buffer) {
Ok(v) => Ok(v),
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
}?;
Ok(config)
}
pub fn build_routes(cfg: &str, config: &mut Config) -> io::Result<()> {
let mut builder = Router::builder();
for (idx, location) in config.location.iter().enumerate() {
for r in &location.route {
if let Err(e) = builder.add(r, location.methods.clone(), idx) {
let msg = format!("{}: [[location]][{}]: route {}: {}", cfg, idx, r, e);
return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
}
}
}
config.router = builder.build();
Ok(())
}
pub fn check(cfg: &str, config: &Config) {
#[cfg(feature = "pam")]
if let Some(AuthType::Pam) = config.accounts.auth_type {
if config.pam.service == "" {
eprintln!("{}: missing section [pam]", cfg);
exit(1);
}
}
if config.server.listen.is_empty() && config.server.tls_listen.is_empty() {
eprintln!("{}: [server]: at least one of listen or tls_listen must be set", cfg);
exit(1);
}
if !config.server.tls_listen.is_empty() {
if config.server.tls_cert.is_none() {
eprintln!("{}: [server]: tls_cert not set", cfg);
exit(1);
}
if config.server.tls_key.is_none() {
eprintln!("{}: [server]: tls_key not set", cfg);
exit(1);
}
}
for (idx, location) in config.location.iter().enumerate() {
if location.setuid {
if !crate::suid::has_thread_switch_ugid() {
eprintln!(
"{}: [[location]][{}]: setuid: uid switching not supported on this OS",
cfg, idx
);
exit(1);
}
if config.server.uid.is_none() || config.server.gid.is_none() {
eprintln!("{}: [server]: missing uid and/or gid", cfg);
exit(1);
}
if config.accounts.acct_type.is_none() && location.accounts.acct_type.is_none() {
eprintln!("{}: [[location]][{}]: setuid: no acct-type set", cfg, idx);
exit(1);
}
}
}
}

635
src/main.rs Normal file
View File

@@ -0,0 +1,635 @@
#![doc(html_root_url = "https://docs.rs/webdav-server/0.4.0")]
//! # `webdav-server` is a webdav server that handles user-accounts.
//!
//! This is a webdav server that allows access to a users home directory,
//! just like an ancient FTP server would (remember those?).
//!
//! This is an application. There is no API documentation here.
//! If you want to build your _own_ webdav server, use the `webdav-handler` crate.
//!
//! See the [GitHub repository](https://github.com/miquels/webdav-server-rs/)
//! for documentation on how to run the server.
//!
#[macro_use]
extern crate log;
mod auth;
mod cache;
mod config;
mod rootfs;
#[doc(hidden)]
pub mod router;
mod suid;
mod tls;
mod unixuser;
mod userfs;
use std::convert::TryFrom;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::os::unix::io::{FromRawFd, AsRawFd};
use std::process::exit;
use std::sync::Arc;
use clap::clap_app;
use headers::{authorization::Basic, Authorization, HeaderMapExt};
use http::status::StatusCode;
use hyper::{
self,
server::conn::{AddrIncoming, AddrStream},
service::{make_service_fn, service_fn},
};
use tls_listener::TlsListener;
use tokio_rustls::server::TlsStream;
use webdav_handler::{davpath::DavPath, DavConfig, DavHandler, DavMethod, DavMethodSet};
use webdav_handler::{fakels::FakeLs, fs::DavFileSystem, ls::DavLockSystem};
use crate::config::{AcctType, Auth, CaseInsensitive, Handler, Location, OnNotfound};
use crate::rootfs::RootFs;
use crate::router::MatchedRoute;
use crate::suid::proc_switch_ugid;
use crate::tls::tls_acceptor;
use crate::userfs::UserFs;
static PROGNAME: &'static str = "webdav-server";
// Contains "state" and a handle to the config.
#[derive(Clone)]
struct Server {
dh: DavHandler,
auth: auth::Auth,
config: Arc<config::Config>,
}
type HttpResult = Result<hyper::Response<webdav_handler::body::Body>, io::Error>;
type HttpRequest = http::Request<hyper::Body>;
// Server implementation.
impl Server {
// Constructor.
pub fn new(config: Arc<config::Config>, auth: auth::Auth) -> Self {
// mostly empty handler.
let ls = FakeLs::new() as Box<dyn DavLockSystem>;
let dh = DavHandler::builder().locksystem(ls).build_handler();
Server { dh, auth, config }
}
// check user account.
async fn acct<'a>(
&'a self,
location: &Location,
auth_user: Option<&'a String>,
user_param: Option<&'a str>,
) -> Result<Option<Arc<unixuser::User>>, StatusCode>
{
// Get username - if any.
let user = match auth_user.map(|u| u.as_str()).or(user_param) {
Some(u) => u,
None => return Ok(None),
};
// If account is not set, fine.
let acct_type = location
.accounts
.acct_type
.as_ref()
.or(self.config.accounts.acct_type.as_ref());
match acct_type {
Some(&AcctType::Unix) => {},
None => return Ok(None),
};
// check if user exists.
let pwd = match cache::cached::unixuser(user, self.config.unix.aux_groups).await {
Ok(pwd) => pwd,
Err(_) => {
debug!("acct: unix: user {} not found", user);
return Err(StatusCode::UNAUTHORIZED);
},
};
// check minimum uid
if let Some(min_uid) = self.config.unix.min_uid {
if pwd.uid < min_uid {
debug!("acct: {}: uid {} too low (<{})", pwd.name, pwd.uid, min_uid);
return Err(StatusCode::FORBIDDEN);
}
}
Ok(Some(pwd))
}
// return a new response::Builder with the Server and CORS header set.
fn response_builder(&self) -> http::response::Builder {
let mut builder = hyper::Response::builder();
self.set_headers(builder.headers_mut().unwrap());
builder
}
// Set Server: webdav-server-rs header, and CORS.
fn set_headers(&self, headers: &mut http::HeaderMap<http::header::HeaderValue>) {
let id = self
.config
.server
.identification
.as_ref()
.map(|s| s.as_str())
.unwrap_or("webdav-server-rs");
if id != "" {
headers.insert("server", id.parse().unwrap());
}
if self.config.server.cors {
headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
headers.insert("Access-Control-Allow-Methods", "GET,HEAD,OPTIONS,PROPFIND".parse().unwrap());
headers.insert("Access-Control-Allow-Headers", "DNT,Depth,Range".parse().unwrap());
}
}
// handle a request.
async fn route(&self, req: HttpRequest, remote_ip: SocketAddr) -> HttpResult {
// Get the URI path.
let davpath = match DavPath::from_uri(req.uri()) {
Ok(p) => p,
Err(_) => return self.error(StatusCode::BAD_REQUEST).await,
};
let path = davpath.as_bytes();
// Get the method.
let method = match DavMethod::try_from(req.method()) {
Ok(m) => m,
Err(_) => return self.error(http::StatusCode::METHOD_NOT_ALLOWED).await,
};
// Request is stored here.
let mut reqdata = Some(req);
let mut got_match = false;
// Match routes to one or more locations.
for route in self
.config
.router
.matches(path, method, &["user", "path"])
.drain(..)
{
got_match = true;
// Take the request from the option.
let req = reqdata.take().unwrap();
// if we might continue, store a clone of the request for the next round.
let location = &self.config.location[*route.data];
if let Some(OnNotfound::Continue) = location.on_notfound {
reqdata.get_or_insert(clone_httpreq(&req));
}
// handle request.
let res = self
.handle(req, method, path, route, location, remote_ip.clone())
.await?;
// no on_notfound? then this is final.
if reqdata.is_none() || res.status() != StatusCode::NOT_FOUND {
return Ok(res);
}
}
if !got_match {
debug!("route: no matching route for {:?}", davpath);
}
self.error(StatusCode::NOT_FOUND).await
}
// handle a request.
async fn handle<'a, 't: 'a, 'p: 'a>(
&'a self,
req: HttpRequest,
method: DavMethod,
path: &'a [u8],
route: MatchedRoute<'t, 'p, usize>,
location: &'a Location,
remote_ip: SocketAddr,
) -> HttpResult
{
// See if we matched a :user parameter
// If so, it must be valid UTF-8, or we return NOT_FOUND.
let user_param = match route.params[0].as_ref() {
Some(p) => {
match p.as_str() {
Some(p) => Some(p),
None => {
debug!("handle: invalid utf-8 in :user part of path");
return self.error(StatusCode::NOT_FOUND).await;
},
}
},
None => None,
};
// Do authentication if needed.
let auth_hdr = req.headers().typed_get::<Authorization<Basic>>();
let do_auth = match location.auth {
Some(Auth::True) => true,
Some(Auth::Write) => !DavMethodSet::WEBDAV_RO.contains(method) || auth_hdr.is_some(),
Some(Auth::False) => false,
Some(Auth::Opportunistic) | None => auth_hdr.is_some(),
};
let auth_user = if do_auth {
let user = match self.auth.auth(&req, location, remote_ip).await {
Ok(user) => user,
Err(status) => return self.auth_error(status, location).await,
};
// if there was a :user in the route, return error if it does not match.
if user_param.map(|u| u != &user).unwrap_or(false) {
debug!("handle: auth user and :user mismatch");
return self.auth_error(StatusCode::UNAUTHORIZED, location).await;
}
Some(user)
} else {
None
};
// Now see if we want to do a account lookup, for uid/gid/homedir.
let pwd = match self.acct(location, auth_user.as_ref(), user_param).await {
Ok(pwd) => pwd,
Err(status) => return self.auth_error(status, location).await,
};
// Expand "~" in the directory.
let dir = match expand_directory(location.directory.as_str(), pwd.as_ref()) {
Ok(d) => d,
Err(_) => return self.error(StatusCode::NOT_FOUND).await,
};
// If :path matched, we can calculate the prefix.
// If it didn't, the entire path _is_ the prefix.
let prefix = match route.params[1].as_ref() {
Some(p) => {
let mut start = p.start();
if start > 0 {
start -= 1;
}
&path[..start]
},
None => path,
};
let prefix = match std::str::from_utf8(prefix) {
Ok(p) => p.to_string(),
Err(_) => {
debug!("handle: prefix is non-UTF8");
return self.error(StatusCode::NOT_FOUND).await;
},
};
// Get User-Agent for user-agent specific modes.
let user_agent = req
.headers()
.get("user-agent")
.and_then(|s| s.to_str().ok())
.unwrap_or("");
// Case insensitivity wanted?
let case_insensitive = match location.case_insensitive {
Some(CaseInsensitive::True) => true,
Some(CaseInsensitive::Ms) => user_agent.contains("Microsoft"),
Some(CaseInsensitive::False) | None => false,
};
// macOS optimizations?
let macos = user_agent.contains("WebDAVFS/") && user_agent.contains("Darwin");
// Get the filesystem.
let auth_ugid = if location.setuid {
pwd.as_ref().map(|p| (p.uid, p.gid, p.groups.as_slice()))
} else {
None
};
let fs = match location.handler {
Handler::Virtroot => {
let auth_user = auth_user.as_ref().map(String::to_owned);
RootFs::new(dir, auth_user, auth_ugid) as Box<dyn DavFileSystem>
},
Handler::Filesystem => {
UserFs::new(dir, auth_ugid, true, case_insensitive, macos) as Box<dyn DavFileSystem>
},
};
// Build a handler.
let methods = location
.methods
.unwrap_or(DavMethodSet::from_vec(vec!["GET", "HEAD"]).unwrap());
let hide_symlinks = location.hide_symlinks.clone().unwrap_or(true);
let mut config = DavConfig::new()
.filesystem(fs)
.strip_prefix(prefix)
.methods(methods)
.hide_symlinks(hide_symlinks)
.autoindex(location.autoindex);
if let Some(auth_user) = auth_user {
config = config.principal(auth_user);
}
if let Some(indexfile) = location.indexfile.clone() {
config = config.indexfile(indexfile);
}
// All set.
self.run_davhandler(config, req).await
}
async fn build_error(&self, code: StatusCode, location: Option<&Location>) -> HttpResult {
let msg = format!(
"<error>{} {}</error>\n",
code.as_u16(),
code.canonical_reason().unwrap_or("")
);
let mut response = self
.response_builder()
.status(code)
.header("Content-Type", "text/xml");
if code == StatusCode::UNAUTHORIZED {
let realm = location.and_then(|location| location.accounts.realm.as_ref());
let realm = realm.or(self.config.accounts.realm.as_ref());
let realm = realm.map(|s| s.as_str()).unwrap_or("Webdav Server");
response = response.header("WWW-Authenticate", format!("Basic realm=\"{}\"", realm).as_str());
}
Ok(response.body(msg.into()).unwrap())
}
async fn auth_error(&self, code: StatusCode, location: &Location) -> HttpResult {
self.build_error(code, Some(location)).await
}
async fn error(&self, code: StatusCode) -> HttpResult {
self.build_error(code, None).await
}
// Call the davhandler, then add headers to the response.
async fn run_davhandler(&self, config: DavConfig, req: HttpRequest) -> HttpResult {
let resp = self.dh.handle_with(config, req).await;
let (mut parts, body) = resp.into_parts();
self.set_headers(&mut parts.headers);
Ok(http::Response::from_parts(parts, body))
}
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// command line option processing.
let matches = clap_app!(webdav_server =>
(version: "0.3")
(@arg CFG: -c --config +takes_value "configuration file (/etc/webdav-server.toml)")
(@arg PORT: -p --port +takes_value "listen to this port on localhost only")
(@arg DBG: -D --debug "enable debug level logging")
)
.get_matches();
if matches.is_present("DBG") {
use env_logger::Env;
let level = "webdav_server=debug,webdav_handler=debug";
env_logger::Builder::from_env(Env::default().default_filter_or(level)).init();
} else {
env_logger::init();
}
let port = matches.value_of("PORT");
let cfg = matches.value_of("CFG").unwrap_or("/etc/webdav-server.toml");
// read config.
let mut config = match config::read(cfg.clone()) {
Err(e) => {
eprintln!("{}: {}: {}", PROGNAME, cfg, e);
exit(1);
},
Ok(c) => c,
};
config::check(cfg.clone(), &config);
// build routes.
if let Err(e) = config::build_routes(cfg.clone(), &mut config) {
eprintln!("{}: {}: {}", PROGNAME, cfg, e);
exit(1);
}
if let Some(port) = port {
let localhosts = vec![
("127.0.0.1:".to_string() + port).parse::<SocketAddr>().unwrap(),
("[::]:".to_string() + port).parse::<SocketAddr>().unwrap(),
];
config.server.listen = config::OneOrManyAddr::Many(localhosts);
}
let config = Arc::new(config);
// set cache timeouts.
if let Some(timeout) = config.unix.cache_timeout {
cache::cached::set_pwcache_timeout(timeout);
}
// resolve addresses.
let addrs = config.server.listen.clone().to_socket_addrs().unwrap_or_else(|e| {
eprintln!("{}: {}: [server] listen: {:?}", PROGNAME, cfg, e);
exit(1);
});
let tls_addrs = config.server.tls_listen.clone().to_socket_addrs().unwrap_or_else(|e| {
eprintln!("{}: {}: [server] listen: {:?}", PROGNAME, cfg, e);
exit(1);
});
// initialize auth early.
let auth = auth::Auth::new(config.clone())?;
// start tokio runtime and initialize the rest from within the runtime.
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_io()
.enable_time()
.build()?;
rt.block_on(async move {
// build servers (one for each listen address).
let dav_server = Server::new(config.clone(), auth);
let mut servers = Vec::new();
let mut tls_servers = Vec::new();
// Plaintext servers.
for sockaddr in addrs {
let listener = match make_listener(sockaddr) {
Ok(l) => l,
Err(e) => {
eprintln!("{}: listener on {:?}: {}", PROGNAME, &sockaddr, e);
exit(1);
},
};
let dav_server = dav_server.clone();
let make_service = make_service_fn(move |socket: &AddrStream| {
let dav_server = dav_server.clone();
let remote_addr = socket.remote_addr();
async move {
let func = move |req| {
let dav_server = dav_server.clone();
async move { dav_server.route(req, remote_addr).await }
};
Ok::<_, hyper::Error>(service_fn(func))
}
});
let incoming = AddrIncoming::from_listener(listener)?;
let server = hyper::Server::builder(incoming);
println!("Listening on http://{:?}", sockaddr);
servers.push(async move {
if let Err(e) = server.serve(make_service).await {
eprintln!("{}: server error: {}", PROGNAME, e);
exit(1);
}
});
}
// TLS servers.
if tls_addrs.len() > 0 {
let tls_acceptor = tls_acceptor(&config.server)?;
for sockaddr in tls_addrs {
let tls_acceptor = tls_acceptor.clone();
let listener = make_listener(sockaddr).unwrap_or_else(|e| {
eprintln!("{}: listener on {:?}: {}", PROGNAME, &sockaddr, e);
exit(1);
});
let dav_server = dav_server.clone();
let make_service = make_service_fn(move |stream: &TlsStream<AddrStream>| {
let dav_server = dav_server.clone();
let remote_addr = stream.get_ref().0.remote_addr();
async move {
let func = move |req| {
let dav_server = dav_server.clone();
async move { dav_server.route(req, remote_addr).await }
};
Ok::<_, hyper::Error>(service_fn(func))
}
});
// Since the server can exit when there's an error on the TlsStream,
// we run it in a loop. Every time the loop is entered we dup() the
// listening fd and create a new TcpListener. This way, we should
// not lose any pending connections during a restart.
let master_listen_fd = listener.as_raw_fd();
std::mem::forget(listener);
println!("Listening on https://{:?}", sockaddr);
tls_servers.push(async move {
loop {
// reuse the incoming socket after the server exits.
let listen_fd = match nix::unistd::dup(master_listen_fd) {
Ok(fd) => fd,
Err(e) => {
eprintln!("{}: server error: dup: {}", PROGNAME, e);
break;
}
};
// SAFETY: listen_fd is unique (we just dup'ed it).
let std_listen = unsafe { std::net::TcpListener::from_raw_fd(listen_fd) };
let listener = match tokio::net::TcpListener::from_std(std_listen) {
Ok(l) => l,
Err(e) => {
eprintln!("{}: server error: new TcpListener: {}", PROGNAME, e);
break;
}
};
let a_incoming = match AddrIncoming::from_listener(listener) {
Ok(a) => a,
Err(e) => {
eprintln!("{}: server error: new AddrIncoming: {}", PROGNAME, e);
break;
}
};
let incoming = TlsListener::new(tls_acceptor.clone(), a_incoming);
let server = hyper::Server::builder(incoming);
if let Err(e) = server.serve(make_service.clone()).await {
eprintln!("{}: server error: {} (retrying)", PROGNAME, e);
}
}
});
}
}
// drop privs.
match (&config.server.uid, &config.server.gid) {
(&Some(uid), &Some(gid)) => {
if !suid::have_suid_privs() {
eprintln!(
"{}: insufficent priviliges to switch uid/gid (not root).",
PROGNAME
);
exit(1);
}
let keep_privs = config.location.iter().any(|l| l.setuid);
proc_switch_ugid(uid, gid, keep_privs);
},
_ => {},
}
// spawn all servers, and wait for them to finish.
let mut tasks = Vec::new();
for server in servers.drain(..) {
tasks.push(tokio::spawn(server));
}
for server in tls_servers.drain(..) {
tasks.push(tokio::spawn(server));
}
for task in tasks.drain(..) {
let _ = task.await;
}
Ok::<_, Box<dyn std::error::Error>>(())
})
}
// Clones a http request with an empty body.
fn clone_httpreq(req: &HttpRequest) -> HttpRequest {
let mut builder = http::Request::builder()
.method(req.method().clone())
.uri(req.uri().clone())
.version(req.version().clone());
for (name, value) in req.headers().iter() {
builder = builder.header(name, value);
}
builder.body(hyper::Body::empty()).unwrap()
}
fn expand_directory(dir: &str, pwd: Option<&Arc<unixuser::User>>) -> Result<String, StatusCode> {
// If it doesn't start with "~", skip.
if !dir.starts_with("~") {
return Ok(dir.to_string());
}
// ~whatever doesn't work.
if dir.len() > 1 && !dir.starts_with("~/") {
debug!("expand_directory: rejecting {}", dir);
return Err(StatusCode::NOT_FOUND);
}
// must have a directory, and that dir must be UTF-8.
let pwd = match pwd {
Some(pwd) => pwd,
None => {
debug!("expand_directory: cannot expand {}: no account", dir);
return Err(StatusCode::NOT_FOUND);
},
};
let homedir = pwd.dir.to_str().ok_or(StatusCode::NOT_FOUND)?;
Ok(format!("{}/{}", homedir, &dir[1..]))
}
// Make a new TcpListener, and if it's a V6 listener, set the
// V6_V6ONLY socket option on it.
fn make_listener(addr: SocketAddr) -> io::Result<tokio::net::TcpListener> {
use socket2::{Domain, SockAddr, Socket, Type, Protocol};
let s = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
if addr.is_ipv6() {
s.set_only_v6(true)?;
}
s.set_nonblocking(true)?;
s.set_nodelay(true)?;
s.set_reuse_address(true)?;
let addr: SockAddr = addr.into();
s.bind(&addr)?;
s.listen(128)?;
let listener: std::net::TcpListener = s.into();
tokio::net::TcpListener::from_std(listener)
}

112
src/rootfs.rs Normal file
View File

@@ -0,0 +1,112 @@
//
// Virtual Root filesystem for PROPFIND.
//
// Shows "/" and "/user".
//
use std;
use std::path::Path;
use futures::future::{self, FutureExt};
use webdav_handler::davpath::DavPath;
use webdav_handler::fs::*;
use crate::userfs::UserFs;
#[derive(Clone)]
pub struct RootFs {
user: String,
fs: UserFs,
}
impl RootFs {
pub fn new<P>(dir: P, user: Option<String>, creds: Option<(u32, u32, &[u32])>) -> Box<RootFs>
where P: AsRef<Path> + Clone {
Box::new(RootFs {
user: user.unwrap_or("".to_string()),
fs: *UserFs::new(dir, creds, false, false, true),
})
}
}
impl DavFileSystem for RootFs {
// Only allow "/" or "/user", for both return the metadata of the UserFs root.
fn metadata<'a>(&'a self, path: &'a DavPath) -> FsFuture<Box<dyn DavMetaData>> {
async move {
let b = path.as_bytes();
if b != b"/" && &b[1..] != self.user.as_bytes() {
return Err(FsError::NotFound);
}
let path = DavPath::new("/").unwrap();
self.fs.metadata(&path).await
}
.boxed()
}
// Only return one entry: "user".
fn read_dir<'a>(
&'a self,
path: &'a DavPath,
_meta: ReadDirMeta,
) -> FsFuture<FsStream<Box<dyn DavDirEntry>>>
{
Box::pin(async move {
let mut v = Vec::new();
if self.user != "" {
v.push(RootFsDirEntry {
name: self.user.clone(),
meta: self.fs.metadata(path).await,
});
}
let strm = futures::stream::iter(RootFsReadDir {
iterator: v.into_iter(),
});
Ok(Box::pin(strm) as FsStream<Box<dyn DavDirEntry>>)
})
}
// cannot open any files.
fn open(&self, _path: &DavPath, _options: OpenOptions) -> FsFuture<Box<dyn DavFile>> {
Box::pin(future::ready(Err(FsError::NotImplemented)))
}
// forward quota.
fn get_quota(&self) -> FsFuture<(u64, Option<u64>)> {
self.fs.get_quota()
}
}
#[derive(Debug)]
struct RootFsReadDir {
iterator: std::vec::IntoIter<RootFsDirEntry>,
}
impl Iterator for RootFsReadDir {
type Item = Box<dyn DavDirEntry>;
fn next(&mut self) -> Option<Box<dyn DavDirEntry>> {
match self.iterator.next() {
None => return None,
Some(entry) => Some(Box::new(entry)),
}
}
}
#[derive(Debug)]
struct RootFsDirEntry {
meta: FsResult<Box<dyn DavMetaData>>,
name: String,
}
impl DavDirEntry for RootFsDirEntry {
fn metadata(&self) -> FsFuture<Box<dyn DavMetaData>> {
Box::pin(future::ready(self.meta.clone()))
}
fn name(&self) -> Vec<u8> {
self.name.as_bytes().to_vec()
}
fn is_dir(&self) -> FsFuture<bool> {
Box::pin(future::ready(Ok(true)))
}
}

262
src/router.rs Normal file
View File

@@ -0,0 +1,262 @@
//!
//! Simple and stupid HTTP router.
//!
use std::default::Default;
use std::fmt::Debug;
use lazy_static::lazy_static;
use regex::bytes::{Match, Regex, RegexSet};
use webdav_handler::{DavMethod, DavMethodSet};
// internal representation of a route.
#[derive(Debug)]
struct Route<T: Debug> {
regex: Regex,
methods: Option<DavMethodSet>,
data: T,
}
/// A matched route.
#[derive(Debug)]
pub struct MatchedRoute<'t, 'p, T: Debug> {
pub methods: Option<DavMethodSet>,
pub params: Vec<Option<Param<'p>>>,
pub data: &'t T,
}
/// A parameter on a matched route.
pub struct Param<'p>(Match<'p>);
impl Debug for Param<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("Param")
.field("start", &self.0.start())
.field("end", &self.0.end())
.field("as_str", &std::str::from_utf8(self.0.as_bytes()).ok())
.finish()
}
}
impl<'p> Param<'p> {
/// Returns the starting byte offset of the match in the path.
#[inline]
pub fn start(&self) -> usize {
self.0.start()
}
/// Returns the ending byte offset of the match in the path.
#[inline]
pub fn end(&self) -> usize {
self.0.end()
}
/// Returns the matched part of the path.
#[inline]
pub fn as_bytes(&self) -> &'p [u8] {
self.0.as_bytes()
}
/// Returns the matched part of the path as a &str, if it is valid utf-8.
#[inline]
pub fn as_str(&self) -> Option<&'p str> {
std::str::from_utf8(self.0.as_bytes()).ok()
}
}
pub struct Builder<T: Debug> {
routes: Vec<Route<T>>,
}
impl<T: Debug> Builder<T> {
/// Add a route.
///
/// Routes are matched in the order they were added.
///
/// If a route starts with '^', it's assumed that it is a regular
/// expression. Parameters are included as "named capture groups".
///
/// Otherwise, it's a route-expression, with just the normal :params
/// and *splat param, and parts between parentheses are optional.
///
/// Example:
///
/// - /api/get/:id
/// - /files/*path
/// - /users(/)
/// - /users(/*path)
///
pub fn add(
&mut self,
route: impl AsRef<str>,
methods: Option<DavMethodSet>,
data: T,
) -> Result<&mut Self, regex::Error>
{
let route = route.as_ref();
// Might be a regexp
if route.starts_with("^") {
return self.add_re(route, methods, data);
}
// Ignore it if it does not start with /
if !route.starts_with("/") {
return Ok(self);
}
// First, replace special characters "()*" with unicode chars
// from the private-use area, so that we can then regex-escape
// the entire string.
let re_route = route
.chars()
.map(|c| {
match c {
'*' => '\u{e001}',
'(' => '\u{e002}',
')' => '\u{e003}',
'\u{e001}' => ' ',
'\u{e002}' => ' ',
'\u{e003}' => ' ',
c => c,
}
})
.collect::<String>();
let re_route = regex::escape(&re_route);
// Translate route expression into regexp.
// We do a simple transformation:
// :ident -> (?P<ident>[^/]*)
// *ident -> (?P<ident>.*)
// (text) -> (?:text|)
lazy_static! {
static ref COLON: Regex = Regex::new(":([a-zA-Z0-9]+)").unwrap();
static ref SPLAT: Regex = Regex::new("\u{e001}([a-zA-Z0-9]+)").unwrap();
static ref MAYBE: Regex = Regex::new("\u{e002}([^\u{e002}]*)\u{e003}").unwrap();
};
let mut re_route = re_route.into_bytes();
re_route = COLON.replace_all(&re_route, &b"(?P<$1>[^/]*)"[..]).to_vec();
re_route = SPLAT.replace_all(&re_route, &b"(?P<$1>.*)"[..]).to_vec();
re_route = MAYBE.replace_all(&re_route, &b"($1)?"[..]).to_vec();
// finalize regex.
let re_route = "^".to_string() + &String::from_utf8(re_route).unwrap() + "$";
self.add_re(&re_route, methods, data)
}
// add route as regular expression.
fn add_re(&mut self, s: &str, methods: Option<DavMethodSet>, data: T) -> Result<&mut Self, regex::Error> {
// Set flags: enable ". matches everything", disable strict unicode.
// We known 's' starts with "^", add it after that.
let s2 = format!("^(?s){}", &s[1..]);
let regex = Regex::new(&s2)?;
self.routes.push(Route { regex, methods, data });
Ok(self)
}
/// Combine all the routes and compile them into an internal RegexSet.
pub fn build(&mut self) -> Router<T> {
let set = RegexSet::new(self.routes.iter().map(|r| r.regex.as_str())).unwrap();
Router {
routes: std::mem::replace(&mut self.routes, Vec::new()),
set,
}
}
}
/// Dead simple HTTP router.
#[derive(Debug)]
pub struct Router<T: Debug> {
set: RegexSet,
routes: Vec<Route<T>>,
}
impl<T: Debug> Default for Router<T> {
fn default() -> Router<T> {
Router {
set: RegexSet::new(&[] as &[&str]).unwrap(),
routes: Vec::new(),
}
}
}
impl<T: Debug> Router<T> {
/// Return a builder.
pub fn builder() -> Builder<T> {
Builder { routes: Vec::new() }
}
/// See if the path matches a route in the set.
///
/// The names of the parameters you want to be returned need to be passed in as an array.
pub fn matches<'a>(
&self,
path: &'a [u8],
method: DavMethod,
param_names: &[&str],
) -> Vec<MatchedRoute<'_, 'a, T>>
{
let mut matched = Vec::new();
for idx in self.set.matches(path) {
let route = &self.routes[idx];
if route.methods.map(|m| m.contains(method)).unwrap_or(true) {
let mut params = Vec::new();
if let Some(caps) = route.regex.captures(path) {
for name in param_names {
params.push(caps.name(name).map(|p| Param(p)));
}
} else {
for _ in param_names {
params.push(None);
}
}
matched.push(MatchedRoute {
methods: route.methods,
params,
data: &route.data,
});
}
}
matched
}
}
#[cfg(test)]
mod tests {
use super::*;
use webdav_handler::DavMethod;
fn test_match(rtr: &Router<usize>, p: &[u8], user: &str, path: &str) {
let x = rtr.matches(p, DavMethod::Get, &["user", "path"]);
assert!(x.len() > 0);
let x = &x[0];
if user != "" {
assert!(x.params[0]
.as_ref()
.map(|b| b.as_bytes() == user.as_bytes())
.unwrap_or(false));
}
if path != "" {
assert!(x.params[1]
.as_ref()
.map(|b| b.as_bytes() == path.as_bytes())
.unwrap_or(false));
}
}
#[test]
fn test_router() -> Result<(), Box<dyn std::error::Error>> {
let rtr = Router::<usize>::builder()
.add("/", None, 1)?
.add("/users(/:user)", None, 2)?
.add("/files/*path", None, 3)?
.add("/files(/*path)", None, 4)?
.build();
test_match(&rtr, b"/", "", "");
test_match(&rtr, b"/users", "", "");
test_match(&rtr, b"/users/", "", "");
test_match(&rtr, b"/users/mike", "mike", "");
test_match(&rtr, b"/files/foo/bar", "", "foo/bar");
test_match(&rtr, b"/files", "", "");
Ok(())
}
}

304
src/suid.rs Normal file
View File

@@ -0,0 +1,304 @@
use std::io;
use std::sync::atomic::{AtomicBool, Ordering};
static THREAD_SWITCH_UGID_USED: AtomicBool = AtomicBool::new(false);
#[cfg(all(target_os = "linux"))]
mod setuid {
// On x86, the default SYS_setresuid is 16 bits. We need to
// import the 32-bit variant.
#[cfg(target_arch = "x86")]
mod uid32 {
pub use libc::SYS_getgroups32 as SYS_getgroups;
pub use libc::SYS_setgroups32 as SYS_setgroups;
pub use libc::SYS_setresgid32 as SYS_setresgid;
pub use libc::SYS_setresuid32 as SYS_setresuid;
}
#[cfg(not(target_arch = "x86"))]
mod uid32 {
pub use libc::{SYS_getgroups, SYS_setgroups, SYS_setresgid, SYS_setresuid};
}
use self::uid32::*;
use std::cell::RefCell;
use std::convert::TryInto;
use std::io;
use std::sync::atomic::Ordering;
const ID_NONE: libc::uid_t = 0xffffffff;
// current credentials of this thread.
struct UgidState {
ruid: u32,
euid: u32,
rgid: u32,
egid: u32,
groups: Vec<u32>,
}
impl UgidState {
fn new() -> UgidState {
super::THREAD_SWITCH_UGID_USED.store(true, Ordering::Release);
UgidState {
ruid: unsafe { libc::getuid() } as u32,
euid: unsafe { libc::geteuid() } as u32,
rgid: unsafe { libc::getgid() } as u32,
egid: unsafe { libc::getegid() } as u32,
groups: getgroups().expect("UgidState::new"),
}
}
}
fn getgroups() -> io::Result<Vec<u32>> {
// get number of groups.
let size = unsafe {
libc::syscall(
SYS_getgroups,
0 as libc::c_int,
std::ptr::null_mut::<libc::gid_t>(),
)
};
if size < 0 {
return Err(oserr(size, "getgroups(0, NULL)"));
}
// get groups.
let mut groups = Vec::<u32>::with_capacity(size as usize);
groups.resize(size as usize, 0);
let res = unsafe { libc::syscall(SYS_getgroups, size as libc::c_int, groups.as_mut_ptr() as *mut _) };
// sanity check.
if res != size {
if res < 0 {
return Err(oserr(res, format!("getgroups({}, buffer)", size)));
}
return Err(io::Error::new(
io::ErrorKind::Other,
format!("getgroups({}, buffer): returned {}", size, res),
));
}
Ok(groups)
}
fn oserr(code: libc::c_long, msg: impl AsRef<str>) -> io::Error {
let msg = msg.as_ref();
let err = io::Error::from_raw_os_error(code.try_into().unwrap());
io::Error::new(err.kind(), format!("{}: {}", msg, err))
}
// thread-local seteuid.
fn seteuid(uid: u32) -> io::Result<()> {
let res = unsafe { libc::syscall(SYS_setresuid, ID_NONE, uid, ID_NONE) };
if res < 0 {
return Err(oserr(res, format!("seteuid({})", uid)));
}
Ok(())
}
// thread-local setegid.
fn setegid(gid: u32) -> io::Result<()> {
let res = unsafe { libc::syscall(SYS_setresgid, ID_NONE, gid, ID_NONE) };
if res < 0 {
return Err(oserr(res, format!("setegid({})", gid)));
}
Ok(())
}
// thread-local setgroups.
fn setgroups(gids: &[u32]) -> io::Result<()> {
let size = gids.len() as libc::c_int;
let res = unsafe { libc::syscall(SYS_setgroups, size, gids.as_ptr() as *const libc::gid_t) };
if res < 0 {
return Err(oserr(res, format!("setgroups({}, {:?}", size, gids)));
}
Ok(())
}
// credential state is thread-local.
thread_local!(static CURRENT_UGID: RefCell<UgidState> = RefCell::new(UgidState::new()));
/// Switch thread credentials.
pub(super) fn thread_switch_ugid(newuid: u32, newgid: u32, newgroups: &[u32]) -> (u32, u32, Vec<u32>) {
CURRENT_UGID.with(|current_ugid| {
let mut cur = current_ugid.borrow_mut();
let (olduid, oldgid, oldgroups) = (cur.euid, cur.egid, cur.groups.clone());
let groups_changed = newgroups != cur.groups.as_slice();
// Check if anything changed.
if newuid != cur.euid || newgid != cur.egid || groups_changed {
// See if we have to switch to root privs first.
if cur.euid != 0 && (newuid != cur.ruid || newgid != cur.rgid || groups_changed) {
// Must first switch to root.
if let Err(e) = seteuid(0) {
panic!("{}", e);
}
cur.euid = 0;
}
if newgid != cur.egid {
// Change gid.
if let Err(e) = setegid(newgid) {
panic!("{}", e);
}
cur.egid = newgid;
}
if groups_changed {
// Change groups.
if let Err(e) = setgroups(newgroups) {
panic!("{}", e);
}
cur.groups.truncate(0);
cur.groups.extend_from_slice(newgroups);
}
if newuid != cur.euid {
// Change uid.
if let Err(e) = seteuid(newuid) {
panic!("{}", e);
}
cur.euid = newuid;
}
}
(olduid, oldgid, oldgroups)
})
}
// Yep..
pub fn has_thread_switch_ugid() -> bool {
true
}
}
#[cfg(not(target_os = "linux"))]
mod setuid {
// Not implemented, as it looks like only Linux has support for
// per-thread uid/gid switching.
//
// DO NOT implement this through libc::setuid, as that will
// switch the uids of all threads.
//
/// Switch thread credentials. Not implemented!
pub(super) fn thread_switch_ugid(_newuid: u32, _newgid: u32, _newgroups: &[u32]) -> (u32, u32, Vec<u32>) {
unimplemented!();
}
// Nope.
pub fn has_thread_switch_ugid() -> bool {
false
}
}
pub use self::setuid::has_thread_switch_ugid;
use self::setuid::thread_switch_ugid;
#[derive(Clone, Debug)]
struct UgidCreds {
pub uid: u32,
pub gid: u32,
pub groups: Vec<u32>,
}
pub struct UgidSwitch {
target_creds: Option<UgidCreds>,
}
pub struct UgidSwitchGuard {
base_creds: Option<UgidCreds>,
}
impl UgidSwitch {
pub fn new(creds: Option<(u32, u32, &[u32])>) -> UgidSwitch {
let target_creds = match creds {
Some((uid, gid, groups)) => {
Some(UgidCreds {
uid,
gid,
groups: groups.into(),
})
},
None => None,
};
UgidSwitch { target_creds }
}
#[allow(dead_code)]
pub fn run<F, R>(&self, func: F) -> R
where F: FnOnce() -> R {
let _guard = self.guard();
func()
}
pub fn guard(&self) -> UgidSwitchGuard {
match &self.target_creds {
&None => UgidSwitchGuard { base_creds: None },
&Some(ref creds) => {
let (uid, gid, groups) = thread_switch_ugid(creds.uid, creds.gid, &creds.groups);
UgidSwitchGuard {
base_creds: Some(UgidCreds { uid, gid, groups }),
}
},
}
}
}
impl Drop for UgidSwitchGuard {
fn drop(&mut self) {
if let Some(ref creds) = self.base_creds {
thread_switch_ugid(creds.uid, creds.gid, &creds.groups);
}
}
}
/// Switch process credentials. Keeps the saved-uid as root, so that
/// we can switch to other ids later on.
pub fn proc_switch_ugid(uid: u32, gid: u32, keep_privs: bool) {
if THREAD_SWITCH_UGID_USED.load(Ordering::Acquire) {
panic!("proc_switch_ugid: called after thread_switch_ugid() has been used");
}
fn last_os_error() -> io::Error {
io::Error::last_os_error()
}
unsafe {
// first get full root privs (real, effective, and saved uids)
if libc::setuid(0) != 0 {
panic!("libc::setuid(0): {:?}", last_os_error());
}
// set real uid, and keep effective uid at 0.
#[cfg(not(any(target_os = "openbsd", target_os = "freebsd")))]
if libc::setreuid(uid, 0) != 0 {
panic!("libc::setreuid({}, 0): {:?}", uid, last_os_error());
}
#[cfg(any(target_os = "openbsd", target_os = "freebsd"))]
if libc::setresuid(uid, 0, 0) != 0 {
panic!("libc::setreuid({}, 0): {:?}", uid, last_os_error());
}
// set group id.
if libc::setgid(gid) != 0 {
panic!("libc::setgid({}): {:?}", gid, last_os_error());
}
// remove _all_ auxilary groups.
if libc::setgroups(0, std::ptr::null::<libc::gid_t>()) != 0 {
panic!("setgroups[]: {:?}", last_os_error());
}
if keep_privs {
// finally set effective uid. saved uid is still 0.
if libc::seteuid(uid) != 0 {
panic!("libc::seteuid({}): {:?}", uid, last_os_error());
}
} else {
// drop all privs.
if libc::setuid(uid) != 0 {
panic!("libc::setuid({}): {:?}", uid, last_os_error());
}
}
}
}
/// Do we have sufficient privs to switch uids?
pub fn have_suid_privs() -> bool {
unsafe { libc::geteuid() == 0 }
}

58
src/tls.rs Normal file
View File

@@ -0,0 +1,58 @@
use std::fs::File;
use std::io::{self, ErrorKind};
use std::sync::Arc;
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
use tokio_rustls::TlsAcceptor;
use rustls_pemfile as pemfile;
use crate::config::Server;
pub fn tls_acceptor(cfg: &Server) -> io::Result<TlsAcceptor> {
// Private key.
let pkey_fn = cfg.tls_key.as_ref().ok_or_else(|| {
io::Error::new(io::ErrorKind::NotFound, "config: server: tls_key not set")
})?;
let pkey_file = File::open(pkey_fn).map_err(|e| {
io::Error::new(e.kind(), format!("{}: {}", pkey_fn, e))
})?;
let mut pkey_file = io::BufReader::new(pkey_file);
let pkey = match pemfile::read_one(&mut pkey_file) {
Ok(Some(pemfile::Item::RSAKey(pkey))) => PrivateKey(pkey),
Ok(Some(pemfile::Item::PKCS8Key(pkey))) => PrivateKey(pkey),
Ok(Some(pemfile::Item::ECKey(pkey))) => PrivateKey(pkey),
Ok(Some(_)) => return Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}: unknown private key format", pkey_fn))),
Ok(None) => return Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}: expected one private key", pkey_fn))),
Err(_) => return Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}: invalid data", pkey_fn))),
};
// Certificate.
let cert_fn = cfg.tls_cert.as_ref().ok_or_else(|| {
io::Error::new(io::ErrorKind::NotFound, "config: server: tls_cert not set")
})?;
let cert_file = File::open(cert_fn).map_err(|e| {
io::Error::new(e.kind(), format!("{}: {}", cert_fn, e))
})?;
let mut cert_file = io::BufReader::new(cert_file);
let certs = pemfile::certs(&mut cert_file).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidData, format!("{}: invalid data", cert_fn))
})?;
let certs = certs
.into_iter()
.map(|cert| Certificate(cert.into()))
.collect();
let config = Arc::new(
ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, pkey)
.map_err(|e| {
io::Error::new(ErrorKind::InvalidData, format!("{}/{}: {}", pkey_fn, cert_fn, e))
})?
).into();
Ok(config)
}

135
src/unixuser.rs Normal file
View File

@@ -0,0 +1,135 @@
use std;
use std::ffi::{CStr, OsStr};
use std::io;
use std::os::unix::ffi::OsStrExt;
use std::path::{Path, PathBuf};
use tokio::task::block_in_place;
#[derive(Debug)]
pub struct User {
pub name: String,
pub passwd: String,
pub gecos: String,
pub uid: u32,
pub gid: u32,
pub groups: Vec<u32>,
pub dir: PathBuf,
pub shell: PathBuf,
}
unsafe fn cptr_to_osstr<'a>(c: *const libc::c_char) -> &'a OsStr {
let bytes = CStr::from_ptr(c).to_bytes();
OsStr::from_bytes(&bytes)
}
unsafe fn cptr_to_path<'a>(c: *const libc::c_char) -> &'a Path {
Path::new(cptr_to_osstr(c))
}
unsafe fn to_user(pwd: &libc::passwd) -> User {
// turn into (unsafe!) rust slices
let cs_name = CStr::from_ptr(pwd.pw_name);
let cs_passwd = CStr::from_ptr(pwd.pw_passwd);
let cs_gecos = CStr::from_ptr(pwd.pw_gecos);
let cs_dir = cptr_to_path(pwd.pw_dir);
let cs_shell = cptr_to_path(pwd.pw_shell);
// then turn the slices into safe owned values.
User {
name: cs_name.to_string_lossy().into_owned(),
passwd: cs_passwd.to_string_lossy().into_owned(),
gecos: cs_gecos.to_string_lossy().into_owned(),
dir: cs_dir.to_path_buf(),
shell: cs_shell.to_path_buf(),
uid: pwd.pw_uid,
gid: pwd.pw_gid,
groups: Vec::new(),
}
}
impl User {
pub fn by_name(name: &str, with_groups: bool) -> Result<User, io::Error> {
let mut buf = [0u8; 1024];
let mut pwd: libc::passwd = unsafe { std::mem::zeroed() };
let mut result: *mut libc::passwd = std::ptr::null_mut();
let cname = match std::ffi::CString::new(name) {
Ok(un) => un,
Err(_) => return Err(io::Error::from_raw_os_error(libc::ENOENT)),
};
let ret = unsafe {
libc::getpwnam_r(
cname.as_ptr(),
&mut pwd as *mut _,
buf.as_mut_ptr() as *mut _,
buf.len() as libc::size_t,
&mut result as *mut _,
)
};
if ret != 0 {
return Err(io::Error::from_raw_os_error(ret));
}
if result.is_null() {
return Err(io::Error::from_raw_os_error(libc::ENOENT));
}
let mut user = unsafe { to_user(&pwd) };
if with_groups {
let mut ngroups = (buf.len() / std::mem::size_of::<libc::gid_t>()) as libc::c_int;
let ret = unsafe {
libc::getgrouplist(
cname.as_ptr(),
user.gid as libc::gid_t,
buf.as_mut_ptr() as *mut _,
&mut ngroups as *mut _,
)
};
if ret >= 0 && ngroups > 0 {
let mut groups_vec = Vec::with_capacity(ngroups as usize);
let groups = unsafe {
std::slice::from_raw_parts(buf.as_ptr() as *const libc::gid_t, ngroups as usize)
};
//
// Only supplementary or auxilary groups, filter out primary.
//
groups_vec.extend(groups.iter().map(|&g| g as u32).filter(|&g| g != user.gid));
user.groups = groups_vec;
}
}
Ok(user)
}
/*
pub fn by_uid(uid: u32) -> Result<User, io::Error> {
let mut buf = [0; 1024];
let mut pwd: libc::passwd = unsafe { std::mem::zeroed() };
let mut result: *mut libc::passwd = std::ptr::null_mut();
let ret = unsafe {
getpwuid_r(
uid,
&mut pwd as *mut _,
buf.as_mut_ptr(),
buf.len() as libc::size_t,
&mut result as *mut _,
)
};
if ret == 0 {
if result.is_null() {
return Err(io::Error::from_raw_os_error(libc::ENOENT));
}
let p = unsafe { to_user(&pwd) };
Ok(p)
} else {
Err(io::Error::from_raw_os_error(ret))
}
}
*/
pub async fn by_name_async(name: &str, with_groups: bool) -> Result<User, io::Error> {
block_in_place(move || User::by_name(name, with_groups))
}
}

125
src/userfs.rs Normal file
View File

@@ -0,0 +1,125 @@
use std::any::Any;
use std::path::{Path, PathBuf};
use webdav_handler::davpath::DavPath;
use webdav_handler::fs::*;
use webdav_handler::localfs::LocalFs;
use crate::suid::UgidSwitch;
#[derive(Clone)]
pub struct UserFs {
pub fs: LocalFs,
basedir: PathBuf,
uid: u32,
}
impl UserFs {
pub fn new(
dir: impl AsRef<Path>,
target_creds: Option<(u32, u32, &[u32])>,
public: bool,
case_insensitive: bool,
macos: bool,
) -> Box<UserFs>
{
// uid is used for quota() calls.
let uid = target_creds.as_ref().map(|ugid| ugid.0).unwrap_or(0);
// set up the LocalFs hooks for uid switching.
let switch = UgidSwitch::new(target_creds.clone());
let blocking_guard = Box::new(move || Box::new(switch.guard()) as Box<dyn Any>);
Box::new(UserFs {
basedir: dir.as_ref().to_path_buf(),
fs: *LocalFs::new_with_fs_access_guard(
dir,
public,
case_insensitive,
macos,
Some(blocking_guard),
),
uid: uid,
})
}
}
impl DavFileSystem for UserFs {
fn metadata<'a>(&'a self, path: &'a DavPath) -> FsFuture<Box<dyn DavMetaData>> {
self.fs.metadata(path)
}
fn symlink_metadata<'a>(&'a self, path: &'a DavPath) -> FsFuture<Box<dyn DavMetaData>> {
self.fs.symlink_metadata(path)
}
fn read_dir<'a>(
&'a self,
path: &'a DavPath,
meta: ReadDirMeta,
) -> FsFuture<FsStream<Box<dyn DavDirEntry>>>
{
self.fs.read_dir(path, meta)
}
fn open<'a>(&'a self, path: &'a DavPath, options: OpenOptions) -> FsFuture<Box<dyn DavFile>> {
self.fs.open(path, options)
}
fn create_dir<'a>(&'a self, path: &'a DavPath) -> FsFuture<()> {
self.fs.create_dir(path)
}
fn remove_dir<'a>(&'a self, path: &'a DavPath) -> FsFuture<()> {
self.fs.remove_dir(path)
}
fn remove_file<'a>(&'a self, path: &'a DavPath) -> FsFuture<()> {
self.fs.remove_file(path)
}
fn rename<'a>(&'a self, from: &'a DavPath, to: &'a DavPath) -> FsFuture<()> {
self.fs.rename(from, to)
}
fn copy<'a>(&'a self, from: &'a DavPath, to: &'a DavPath) -> FsFuture<()> {
self.fs.copy(from, to)
}
#[cfg(feature = "quota")]
fn get_quota<'a>(&'a self) -> FsFuture<(u64, Option<u64>)> {
use crate::cache;
use fs_quota::*;
use futures::future::FutureExt;
use std::time::Duration;
lazy_static::lazy_static! {
static ref QCACHE: cache::Cache<PathBuf, FsQuota> = cache::Cache::new().maxage(Duration::new(30, 0));
}
async move {
let mut key = self.basedir.clone();
key.push(&self.uid.to_string());
let r = match QCACHE.get(&key) {
Some(r) => {
debug!("get_quota for {:?}: from cache", key);
r
},
None => {
let path = self.basedir.clone();
let uid = self.uid;
let r = self
.fs
.blocking(move || {
FsQuota::check(&path, Some(uid)).map_err(|_| FsError::GeneralFailure)
})
.await?;
debug!("get_quota for {:?}: insert to cache", key);
QCACHE.insert(key, r)
},
};
Ok((r.bytes_used, r.bytes_limit))
}
.boxed()
}
}