feat: clone from https://github.com/miquels/webdav-server-rs
This commit is contained in:
156
src/auth.rs
Normal file
156
src/auth.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::{AuthType, Config, Location};
|
||||
|
||||
use headers::{authorization::Basic, Authorization, HeaderMapExt};
|
||||
use http::status::StatusCode;
|
||||
|
||||
type HttpRequest = http::Request<hyper::Body>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Auth {
|
||||
config: Arc<Config>,
|
||||
#[cfg(feature = "pam")]
|
||||
pam_auth: pam_sandboxed::PamAuth,
|
||||
}
|
||||
|
||||
impl Auth {
|
||||
pub fn new(config: Arc<Config>) -> io::Result<Auth> {
|
||||
// initialize pam.
|
||||
#[cfg(feature = "pam")]
|
||||
let pam_auth = {
|
||||
// set cache timeouts.
|
||||
if let Some(timeout) = config.pam.cache_timeout {
|
||||
crate::cache::cached::set_pamcache_timeout(timeout);
|
||||
}
|
||||
pam_sandboxed::PamAuth::new(config.pam.threads.clone())?
|
||||
};
|
||||
|
||||
Ok(Auth {
|
||||
#[cfg(feature = "pam")]
|
||||
pam_auth,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
// authenticate user.
|
||||
pub async fn auth<'a>(
|
||||
&'a self,
|
||||
req: &'a HttpRequest,
|
||||
location: &Location,
|
||||
_remote_ip: SocketAddr,
|
||||
) -> Result<String, StatusCode>
|
||||
{
|
||||
// we must have a login/pass
|
||||
let basic = match req.headers().typed_get::<Authorization<Basic>>() {
|
||||
Some(Authorization(basic)) => basic,
|
||||
_ => return Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
let user = basic.username();
|
||||
let pass = basic.password();
|
||||
|
||||
// match the auth type.
|
||||
let auth_type = location
|
||||
.accounts
|
||||
.auth_type
|
||||
.as_ref()
|
||||
.or(self.config.accounts.auth_type.as_ref());
|
||||
match auth_type {
|
||||
#[cfg(feature = "pam")]
|
||||
Some(&AuthType::Pam) => self.auth_pam(req, user, pass, _remote_ip).await,
|
||||
Some(&AuthType::HtPasswd(ref ht)) => self.auth_htpasswd(user, pass, ht.as_str()).await,
|
||||
None => {
|
||||
debug!("need authentication, but auth-type is not set");
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// authenticate user using PAM.
|
||||
#[cfg(feature = "pam")]
|
||||
async fn auth_pam<'a>(
|
||||
&'a self,
|
||||
req: &'a HttpRequest,
|
||||
user: &'a str,
|
||||
pass: &'a str,
|
||||
remote_ip: SocketAddr,
|
||||
) -> Result<String, StatusCode>
|
||||
{
|
||||
// stringify the remote IP address.
|
||||
let ip = remote_ip.ip();
|
||||
let ip_string = if ip.is_loopback() {
|
||||
// if it's loopback, take the value from the x-forwarded-for
|
||||
// header, if present.
|
||||
req.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|s| s.to_str().ok())
|
||||
.and_then(|s| s.split(',').next())
|
||||
.map(|s| s.trim().to_owned())
|
||||
} else {
|
||||
Some(match ip {
|
||||
std::net::IpAddr::V4(ip) => ip.to_string(),
|
||||
std::net::IpAddr::V6(ip) => ip.to_string(),
|
||||
})
|
||||
};
|
||||
let ip_ref = ip_string.as_ref().map(|s| s.as_str());
|
||||
|
||||
// authenticate.
|
||||
let service = self.config.pam.service.as_str();
|
||||
let pam_auth = self.pam_auth.clone();
|
||||
match crate::cache::cached::pam_auth(pam_auth, service, user, pass, ip_ref).await {
|
||||
Ok(_) => Ok(user.to_string()),
|
||||
Err(_) => {
|
||||
debug!(
|
||||
"auth_pam({}): authentication for {} ({:?}) failed",
|
||||
service, user, ip_ref
|
||||
);
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// authenticate user using htpasswd.
|
||||
async fn auth_htpasswd<'a>(
|
||||
&'a self,
|
||||
user: &'a str,
|
||||
pass: &'a str,
|
||||
section: &'a str,
|
||||
) -> Result<String, StatusCode>
|
||||
{
|
||||
// Get the htpasswd.WHATEVER section from the config file.
|
||||
let file = match self.config.htpasswd.get(section) {
|
||||
Some(section) => section.htpasswd.as_str(),
|
||||
None => return Err(StatusCode::UNAUTHORIZED),
|
||||
};
|
||||
|
||||
// Read the file and split it into a bunch of lines.
|
||||
tokio::task::block_in_place(move || {
|
||||
let data = match std::fs::read_to_string(file) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
debug!("{}: {}", file, e);
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
},
|
||||
};
|
||||
let lines = data
|
||||
.split('\n')
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.starts_with("#") && !s.is_empty());
|
||||
|
||||
// Check each line for a match.
|
||||
for line in lines {
|
||||
let mut fields = line.split(':');
|
||||
if let (Some(htuser), Some(htpass)) = (fields.next(), fields.next()) {
|
||||
if htuser == user && pwhash::unix::verify(pass, htpass) {
|
||||
return Ok(user.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("auth_htpasswd: authentication for {} failed", user);
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
})
|
||||
}
|
||||
}
|
||||
185
src/cache.rs
Normal file
185
src/cache.rs
Normal file
@@ -0,0 +1,185 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::cmp::Eq;
|
||||
use std::collections::vec_deque::VecDeque;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use std::option::Option;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct Cache<K, V> {
|
||||
intern: Mutex<Intern<K, V>>,
|
||||
}
|
||||
|
||||
struct Intern<K, V> {
|
||||
maxsize: usize,
|
||||
maxage: Duration,
|
||||
map: HashMap<K, Arc<V>>,
|
||||
fifo: VecDeque<(Instant, K)>,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone, V> Cache<K, V> {
|
||||
pub fn new() -> Cache<K, V> {
|
||||
let i = Intern {
|
||||
maxsize: 0,
|
||||
maxage: Duration::new(0, 0),
|
||||
map: HashMap::new(),
|
||||
fifo: VecDeque::new(),
|
||||
};
|
||||
Cache {
|
||||
intern: Mutex::new(i),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn maxsize(self, maxsize: usize) -> Self {
|
||||
self.intern.lock().unwrap().maxsize = maxsize;
|
||||
self
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn maxage(self, maxage: Duration) -> Self {
|
||||
self.intern.lock().unwrap().maxage = maxage;
|
||||
self
|
||||
}
|
||||
|
||||
fn expire(&self, m: &mut Intern<K, V>) {
|
||||
let mut n = m.fifo.len();
|
||||
if m.maxsize > 0 && n >= m.maxsize {
|
||||
n = m.maxsize;
|
||||
}
|
||||
if m.maxage.as_secs() > 0 || m.maxage.subsec_nanos() > 0 {
|
||||
let now = Instant::now();
|
||||
while n > 0 {
|
||||
let &(t, _) = m.fifo.get(n - 1).unwrap();
|
||||
if now.duration_since(t) <= m.maxage {
|
||||
break;
|
||||
}
|
||||
n -= 1;
|
||||
}
|
||||
}
|
||||
for x in n..m.fifo.len() {
|
||||
let &(_, ref key) = m.fifo.get(x).unwrap();
|
||||
m.map.remove(&key);
|
||||
}
|
||||
m.fifo.truncate(n);
|
||||
}
|
||||
|
||||
pub fn insert(&self, key: K, val: V) -> Arc<V> {
|
||||
let mut m = self.intern.lock().unwrap();
|
||||
self.expire(&mut *m);
|
||||
let av = Arc::new(val);
|
||||
let ac = av.clone();
|
||||
m.map.insert(key.clone(), av);
|
||||
m.fifo.push_front((Instant::now(), key));
|
||||
ac
|
||||
}
|
||||
|
||||
// see https://doc.rust-lang.org/book/first-edition/borrow-and-asref.html
|
||||
pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<Arc<V>>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq,
|
||||
{
|
||||
let mut m = self.intern.lock().unwrap();
|
||||
self.expire(&mut *m);
|
||||
if let Some(v) = m.map.get(key) {
|
||||
return Some(v.clone());
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod cached {
|
||||
//
|
||||
// Cached versions of Unix account lookup and Pam auth.
|
||||
//
|
||||
use std::io;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::cache;
|
||||
use crate::unixuser::{self, User};
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
struct Timeouts {
|
||||
pwcache: Duration,
|
||||
pamcache: Duration,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref TIMEOUTS: Mutex<Timeouts> = Mutex::new(Timeouts {
|
||||
pwcache: Duration::new(120, 0),
|
||||
pamcache: Duration::new(120, 0),
|
||||
});
|
||||
static ref PWCACHE: cache::Cache<String, unixuser::User> = new_pwcache();
|
||||
static ref PAMCACHE: cache::Cache<u64, String> = new_pamcache();
|
||||
}
|
||||
|
||||
fn new_pwcache() -> cache::Cache<String, unixuser::User> {
|
||||
let timeouts = TIMEOUTS.lock().unwrap();
|
||||
cache::Cache::new().maxage(timeouts.pwcache)
|
||||
}
|
||||
|
||||
fn new_pamcache() -> cache::Cache<u64, String> {
|
||||
let timeouts = TIMEOUTS.lock().unwrap();
|
||||
cache::Cache::new().maxage(timeouts.pamcache)
|
||||
}
|
||||
|
||||
pub(crate) fn set_pwcache_timeout(secs: usize) {
|
||||
let mut timeouts = TIMEOUTS.lock().unwrap();
|
||||
timeouts.pwcache = Duration::new(secs as u64, 0);
|
||||
}
|
||||
|
||||
#[cfg(feature = "pam")]
|
||||
pub(crate) fn set_pamcache_timeout(secs: usize) {
|
||||
let mut timeouts = TIMEOUTS.lock().unwrap();
|
||||
timeouts.pamcache = Duration::new(secs as u64, 0);
|
||||
}
|
||||
|
||||
#[cfg(feature = "pam")]
|
||||
pub async fn pam_auth<'a>(
|
||||
pam_auth: pam_sandboxed::PamAuth,
|
||||
service: &'a str,
|
||||
user: &'a str,
|
||||
pass: &'a str,
|
||||
remip: Option<&'a str>,
|
||||
) -> Result<(), pam_sandboxed::PamError>
|
||||
{
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
let mut s = DefaultHasher::new();
|
||||
service.hash(&mut s);
|
||||
user.hash(&mut s);
|
||||
pass.hash(&mut s);
|
||||
remip.as_ref().hash(&mut s);
|
||||
let key = s.finish();
|
||||
|
||||
if let Some(cache_user) = PAMCACHE.get(&key) {
|
||||
if user == cache_user.as_str() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let mut pam_auth = pam_auth;
|
||||
match pam_auth.auth(&service, &user, &pass, remip).await {
|
||||
Err(e) => Err(e),
|
||||
Ok(()) => {
|
||||
PAMCACHE.insert(key, user.to_owned());
|
||||
Ok(())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn unixuser(username: &str, with_groups: bool) -> Result<Arc<User>, io::Error> {
|
||||
if let Some(pwd) = PWCACHE.get(username) {
|
||||
return Ok(pwd);
|
||||
}
|
||||
match User::by_name_async(username, with_groups).await {
|
||||
Err(e) => Err(e),
|
||||
Ok(pwd) => Ok(PWCACHE.insert(username.to_owned(), pwd)),
|
||||
}
|
||||
}
|
||||
}
|
||||
339
src/config.rs
Normal file
339
src/config.rs
Normal file
@@ -0,0 +1,339 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
use std::path::Path;
|
||||
use std::process::exit;
|
||||
use std::{fs, io};
|
||||
|
||||
use enum_from_str::ParseEnumVariantError;
|
||||
use enum_from_str_derive::FromStr;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use toml;
|
||||
use webdav_handler::DavMethodSet;
|
||||
|
||||
use crate::router::Router;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Config {
|
||||
pub server: Server,
|
||||
#[serde(default)]
|
||||
pub accounts: Accounts,
|
||||
#[serde(default)]
|
||||
pub pam: Pam,
|
||||
#[serde(default)]
|
||||
pub htpasswd: HashMap<String, HtPasswd>,
|
||||
#[serde(default)]
|
||||
pub unix: Unix,
|
||||
#[serde(default)]
|
||||
pub location: Vec<Location>,
|
||||
#[serde(skip)]
|
||||
pub router: Router<usize>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Server {
|
||||
#[serde(default)]
|
||||
pub listen: OneOrManyAddr,
|
||||
#[serde(default)]
|
||||
pub tls_listen: OneOrManyAddr,
|
||||
#[serde(default)]
|
||||
pub tls_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub tls_cert: Option<String>,
|
||||
//#[serde(deserialize_with = "deserialize_user", default)]
|
||||
pub uid: Option<u32>,
|
||||
//#[serde(deserialize_with = "deserialize_group", default)]
|
||||
pub gid: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub identification: Option<String>,
|
||||
#[serde(default)]
|
||||
pub cors: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
pub struct Accounts {
|
||||
#[serde(rename = "auth-type", deserialize_with = "deserialize_authtype", default)]
|
||||
pub auth_type: Option<AuthType>,
|
||||
#[serde(rename = "acct-type", deserialize_with = "deserialize_opt_enum", default)]
|
||||
pub acct_type: Option<AcctType>,
|
||||
#[serde(default)]
|
||||
pub realm: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
pub struct Pam {
|
||||
pub service: String,
|
||||
#[serde(rename = "cache-timeout")]
|
||||
pub cache_timeout: Option<usize>,
|
||||
pub threads: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
pub struct HtPasswd {
|
||||
pub htpasswd: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Default)]
|
||||
pub struct Unix {
|
||||
#[serde(rename = "cache-timeout")]
|
||||
pub cache_timeout: Option<usize>,
|
||||
#[serde(rename = "min-uid", default)]
|
||||
pub min_uid: Option<u32>,
|
||||
#[serde(rename = "supplementary-groups", default)]
|
||||
pub aux_groups: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Location {
|
||||
#[serde(default)]
|
||||
pub route: Vec<String>,
|
||||
#[serde(deserialize_with = "deserialize_methodset", default)]
|
||||
pub methods: Option<DavMethodSet>,
|
||||
#[serde(deserialize_with = "deserialize_opt_enum", default)]
|
||||
pub auth: Option<Auth>,
|
||||
#[serde(default, flatten)]
|
||||
pub accounts: Accounts,
|
||||
#[serde(deserialize_with = "deserialize_enum")]
|
||||
pub handler: Handler,
|
||||
#[serde(default)]
|
||||
pub setuid: bool,
|
||||
pub directory: String,
|
||||
#[serde(default, alias = "hide-symlinks")]
|
||||
pub hide_symlinks: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub indexfile: Option<String>,
|
||||
#[serde(default)]
|
||||
pub autoindex: bool,
|
||||
#[serde(
|
||||
rename = "case-insensitive",
|
||||
deserialize_with = "deserialize_opt_enum",
|
||||
default
|
||||
)]
|
||||
pub case_insensitive: Option<CaseInsensitive>,
|
||||
#[serde(deserialize_with = "deserialize_opt_enum", default)]
|
||||
pub on_notfound: Option<OnNotfound>,
|
||||
}
|
||||
|
||||
#[derive(FromStr, Debug, Clone, Copy)]
|
||||
pub enum Handler {
|
||||
#[from_str = "virtroot"]
|
||||
Virtroot,
|
||||
#[from_str = "filesystem"]
|
||||
Filesystem,
|
||||
}
|
||||
|
||||
#[derive(FromStr, Debug, Clone, Copy)]
|
||||
pub enum Auth {
|
||||
#[from_str = "false"]
|
||||
False,
|
||||
#[from_str = "true"]
|
||||
True,
|
||||
#[from_str = "opportunistic"]
|
||||
Opportunistic,
|
||||
#[from_str = "write"]
|
||||
Write,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AuthType {
|
||||
#[cfg(feature = "pam")]
|
||||
Pam,
|
||||
HtPasswd(String),
|
||||
}
|
||||
|
||||
#[derive(FromStr, Debug, Clone, Copy)]
|
||||
pub enum AcctType {
|
||||
#[from_str = "unix"]
|
||||
Unix,
|
||||
}
|
||||
|
||||
#[derive(FromStr, Debug, Clone, Copy)]
|
||||
pub enum CaseInsensitive {
|
||||
#[from_str = "true"]
|
||||
True,
|
||||
#[from_str = "ms"]
|
||||
Ms,
|
||||
#[from_str = "false"]
|
||||
False,
|
||||
}
|
||||
|
||||
#[derive(FromStr, Debug, Clone, Copy)]
|
||||
pub enum OnNotfound {
|
||||
#[from_str = "continue"]
|
||||
Continue,
|
||||
#[from_str = "return"]
|
||||
Return,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum OneOrManyAddr {
|
||||
One(SocketAddr),
|
||||
Many(Vec<SocketAddr>),
|
||||
}
|
||||
|
||||
impl OneOrManyAddr {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
OneOrManyAddr::One(_) => false,
|
||||
OneOrManyAddr::Many(v) => v.is_empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OneOrManyAddr {
|
||||
fn default() -> Self {
|
||||
OneOrManyAddr::Many(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSocketAddrs for OneOrManyAddr {
|
||||
type Iter = std::vec::IntoIter<SocketAddr>;
|
||||
fn to_socket_addrs(&self) -> io::Result<std::vec::IntoIter<SocketAddr>> {
|
||||
let i = match self {
|
||||
OneOrManyAddr::Many(ref v) => v.to_owned(),
|
||||
OneOrManyAddr::One(ref s) => vec![*s],
|
||||
};
|
||||
Ok(i.into_iter())
|
||||
}
|
||||
}
|
||||
|
||||
// keep this here for now, we might implement a enum{(u32, String} later for
|
||||
// usernames and groupnames.
|
||||
#[allow(unused)]
|
||||
pub fn deserialize_user<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
|
||||
where D: Deserializer<'de> {
|
||||
let s = String::deserialize(deserializer)?;
|
||||
s.parse::<u32>()
|
||||
.map(|v| Some(v))
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn deserialize_group<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
|
||||
where D: Deserializer<'de> {
|
||||
let s = String::deserialize(deserializer)?;
|
||||
s.parse::<u32>()
|
||||
.map(|v| Some(v))
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
pub fn deserialize_methodset<'de, D>(deserializer: D) -> Result<Option<DavMethodSet>, D::Error>
|
||||
where D: Deserializer<'de> {
|
||||
let m = Vec::<String>::deserialize(deserializer)?;
|
||||
DavMethodSet::from_vec(m)
|
||||
.map(|v| Some(v))
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
pub fn deserialize_authtype<'de, D>(deserializer: D) -> Result<Option<AuthType>, D::Error>
|
||||
where D: Deserializer<'de> {
|
||||
let s = String::deserialize(deserializer)?;
|
||||
if s.starts_with("htpasswd.") {
|
||||
return Ok(Some(AuthType::HtPasswd(s[9..].to_string())));
|
||||
}
|
||||
#[cfg(feature = "pam")]
|
||||
if &s == "pam" {
|
||||
return Ok(Some(AuthType::Pam));
|
||||
}
|
||||
if s == "" {
|
||||
return Ok(None);
|
||||
}
|
||||
Err(serde::de::Error::custom("unknown auth-type"))
|
||||
}
|
||||
|
||||
pub fn deserialize_opt_enum<'de, D, E>(deserializer: D) -> Result<Option<E>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
E: std::str::FromStr,
|
||||
E::Err: std::fmt::Display,
|
||||
{
|
||||
String::deserialize(deserializer)?
|
||||
.as_str()
|
||||
.parse::<E>()
|
||||
.map(|e| Some(e))
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
pub fn deserialize_enum<'de, D, E>(deserializer: D) -> Result<E, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
E: std::str::FromStr,
|
||||
E::Err: std::fmt::Display,
|
||||
{
|
||||
String::deserialize(deserializer)?
|
||||
.as_str()
|
||||
.parse::<E>()
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
|
||||
// Read the TOML config into a config::Config struct.
|
||||
pub fn read(toml_file: impl AsRef<Path>) -> io::Result<Config> {
|
||||
let buffer = fs::read_to_string(&toml_file)?;
|
||||
|
||||
// initial parse.
|
||||
let config: Config = match toml::from_str(&buffer) {
|
||||
Ok(v) => Ok(v),
|
||||
Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
|
||||
}?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn build_routes(cfg: &str, config: &mut Config) -> io::Result<()> {
|
||||
let mut builder = Router::builder();
|
||||
for (idx, location) in config.location.iter().enumerate() {
|
||||
for r in &location.route {
|
||||
if let Err(e) = builder.add(r, location.methods.clone(), idx) {
|
||||
let msg = format!("{}: [[location]][{}]: route {}: {}", cfg, idx, r, e);
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
config.router = builder.build();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn check(cfg: &str, config: &Config) {
|
||||
#[cfg(feature = "pam")]
|
||||
if let Some(AuthType::Pam) = config.accounts.auth_type {
|
||||
if config.pam.service == "" {
|
||||
eprintln!("{}: missing section [pam]", cfg);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if config.server.listen.is_empty() && config.server.tls_listen.is_empty() {
|
||||
eprintln!("{}: [server]: at least one of listen or tls_listen must be set", cfg);
|
||||
exit(1);
|
||||
}
|
||||
if !config.server.tls_listen.is_empty() {
|
||||
if config.server.tls_cert.is_none() {
|
||||
eprintln!("{}: [server]: tls_cert not set", cfg);
|
||||
exit(1);
|
||||
}
|
||||
if config.server.tls_key.is_none() {
|
||||
eprintln!("{}: [server]: tls_key not set", cfg);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
for (idx, location) in config.location.iter().enumerate() {
|
||||
if location.setuid {
|
||||
if !crate::suid::has_thread_switch_ugid() {
|
||||
eprintln!(
|
||||
"{}: [[location]][{}]: setuid: uid switching not supported on this OS",
|
||||
cfg, idx
|
||||
);
|
||||
exit(1);
|
||||
}
|
||||
if config.server.uid.is_none() || config.server.gid.is_none() {
|
||||
eprintln!("{}: [server]: missing uid and/or gid", cfg);
|
||||
exit(1);
|
||||
}
|
||||
if config.accounts.acct_type.is_none() && location.accounts.acct_type.is_none() {
|
||||
eprintln!("{}: [[location]][{}]: setuid: no acct-type set", cfg, idx);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
635
src/main.rs
Normal file
635
src/main.rs
Normal file
@@ -0,0 +1,635 @@
|
||||
#![doc(html_root_url = "https://docs.rs/webdav-server/0.4.0")]
|
||||
//! # `webdav-server` is a webdav server that handles user-accounts.
|
||||
//!
|
||||
//! This is a webdav server that allows access to a users home directory,
|
||||
//! just like an ancient FTP server would (remember those?).
|
||||
//!
|
||||
//! This is an application. There is no API documentation here.
|
||||
//! If you want to build your _own_ webdav server, use the `webdav-handler` crate.
|
||||
//!
|
||||
//! See the [GitHub repository](https://github.com/miquels/webdav-server-rs/)
|
||||
//! for documentation on how to run the server.
|
||||
//!
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
mod auth;
|
||||
mod cache;
|
||||
mod config;
|
||||
mod rootfs;
|
||||
#[doc(hidden)]
|
||||
pub mod router;
|
||||
mod suid;
|
||||
mod tls;
|
||||
mod unixuser;
|
||||
mod userfs;
|
||||
|
||||
use std::convert::TryFrom;
|
||||
use std::io;
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
use std::os::unix::io::{FromRawFd, AsRawFd};
|
||||
use std::process::exit;
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::clap_app;
|
||||
use headers::{authorization::Basic, Authorization, HeaderMapExt};
|
||||
use http::status::StatusCode;
|
||||
use hyper::{
|
||||
self,
|
||||
server::conn::{AddrIncoming, AddrStream},
|
||||
service::{make_service_fn, service_fn},
|
||||
};
|
||||
use tls_listener::TlsListener;
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use webdav_handler::{davpath::DavPath, DavConfig, DavHandler, DavMethod, DavMethodSet};
|
||||
use webdav_handler::{fakels::FakeLs, fs::DavFileSystem, ls::DavLockSystem};
|
||||
|
||||
use crate::config::{AcctType, Auth, CaseInsensitive, Handler, Location, OnNotfound};
|
||||
use crate::rootfs::RootFs;
|
||||
use crate::router::MatchedRoute;
|
||||
use crate::suid::proc_switch_ugid;
|
||||
use crate::tls::tls_acceptor;
|
||||
use crate::userfs::UserFs;
|
||||
|
||||
static PROGNAME: &'static str = "webdav-server";
|
||||
|
||||
// Contains "state" and a handle to the config.
|
||||
#[derive(Clone)]
|
||||
struct Server {
|
||||
dh: DavHandler,
|
||||
auth: auth::Auth,
|
||||
config: Arc<config::Config>,
|
||||
}
|
||||
|
||||
type HttpResult = Result<hyper::Response<webdav_handler::body::Body>, io::Error>;
|
||||
type HttpRequest = http::Request<hyper::Body>;
|
||||
|
||||
// Server implementation.
|
||||
impl Server {
|
||||
// Constructor.
|
||||
pub fn new(config: Arc<config::Config>, auth: auth::Auth) -> Self {
|
||||
// mostly empty handler.
|
||||
let ls = FakeLs::new() as Box<dyn DavLockSystem>;
|
||||
let dh = DavHandler::builder().locksystem(ls).build_handler();
|
||||
|
||||
Server { dh, auth, config }
|
||||
}
|
||||
|
||||
// check user account.
|
||||
async fn acct<'a>(
|
||||
&'a self,
|
||||
location: &Location,
|
||||
auth_user: Option<&'a String>,
|
||||
user_param: Option<&'a str>,
|
||||
) -> Result<Option<Arc<unixuser::User>>, StatusCode>
|
||||
{
|
||||
// Get username - if any.
|
||||
let user = match auth_user.map(|u| u.as_str()).or(user_param) {
|
||||
Some(u) => u,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
// If account is not set, fine.
|
||||
let acct_type = location
|
||||
.accounts
|
||||
.acct_type
|
||||
.as_ref()
|
||||
.or(self.config.accounts.acct_type.as_ref());
|
||||
match acct_type {
|
||||
Some(&AcctType::Unix) => {},
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
// check if user exists.
|
||||
let pwd = match cache::cached::unixuser(user, self.config.unix.aux_groups).await {
|
||||
Ok(pwd) => pwd,
|
||||
Err(_) => {
|
||||
debug!("acct: unix: user {} not found", user);
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
},
|
||||
};
|
||||
|
||||
// check minimum uid
|
||||
if let Some(min_uid) = self.config.unix.min_uid {
|
||||
if pwd.uid < min_uid {
|
||||
debug!("acct: {}: uid {} too low (<{})", pwd.name, pwd.uid, min_uid);
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
Ok(Some(pwd))
|
||||
}
|
||||
|
||||
// return a new response::Builder with the Server and CORS header set.
|
||||
fn response_builder(&self) -> http::response::Builder {
|
||||
let mut builder = hyper::Response::builder();
|
||||
self.set_headers(builder.headers_mut().unwrap());
|
||||
builder
|
||||
}
|
||||
|
||||
// Set Server: webdav-server-rs header, and CORS.
|
||||
fn set_headers(&self, headers: &mut http::HeaderMap<http::header::HeaderValue>) {
|
||||
let id = self
|
||||
.config
|
||||
.server
|
||||
.identification
|
||||
.as_ref()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("webdav-server-rs");
|
||||
if id != "" {
|
||||
headers.insert("server", id.parse().unwrap());
|
||||
}
|
||||
if self.config.server.cors {
|
||||
headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
|
||||
headers.insert("Access-Control-Allow-Methods", "GET,HEAD,OPTIONS,PROPFIND".parse().unwrap());
|
||||
headers.insert("Access-Control-Allow-Headers", "DNT,Depth,Range".parse().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// handle a request.
|
||||
async fn route(&self, req: HttpRequest, remote_ip: SocketAddr) -> HttpResult {
|
||||
// Get the URI path.
|
||||
let davpath = match DavPath::from_uri(req.uri()) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return self.error(StatusCode::BAD_REQUEST).await,
|
||||
};
|
||||
let path = davpath.as_bytes();
|
||||
|
||||
// Get the method.
|
||||
let method = match DavMethod::try_from(req.method()) {
|
||||
Ok(m) => m,
|
||||
Err(_) => return self.error(http::StatusCode::METHOD_NOT_ALLOWED).await,
|
||||
};
|
||||
|
||||
// Request is stored here.
|
||||
let mut reqdata = Some(req);
|
||||
let mut got_match = false;
|
||||
|
||||
// Match routes to one or more locations.
|
||||
for route in self
|
||||
.config
|
||||
.router
|
||||
.matches(path, method, &["user", "path"])
|
||||
.drain(..)
|
||||
{
|
||||
got_match = true;
|
||||
|
||||
// Take the request from the option.
|
||||
let req = reqdata.take().unwrap();
|
||||
|
||||
// if we might continue, store a clone of the request for the next round.
|
||||
let location = &self.config.location[*route.data];
|
||||
if let Some(OnNotfound::Continue) = location.on_notfound {
|
||||
reqdata.get_or_insert(clone_httpreq(&req));
|
||||
}
|
||||
|
||||
// handle request.
|
||||
let res = self
|
||||
.handle(req, method, path, route, location, remote_ip.clone())
|
||||
.await?;
|
||||
|
||||
// no on_notfound? then this is final.
|
||||
if reqdata.is_none() || res.status() != StatusCode::NOT_FOUND {
|
||||
return Ok(res);
|
||||
}
|
||||
}
|
||||
|
||||
if !got_match {
|
||||
debug!("route: no matching route for {:?}", davpath);
|
||||
}
|
||||
|
||||
self.error(StatusCode::NOT_FOUND).await
|
||||
}
|
||||
|
||||
// handle a request.
|
||||
async fn handle<'a, 't: 'a, 'p: 'a>(
|
||||
&'a self,
|
||||
req: HttpRequest,
|
||||
method: DavMethod,
|
||||
path: &'a [u8],
|
||||
route: MatchedRoute<'t, 'p, usize>,
|
||||
location: &'a Location,
|
||||
remote_ip: SocketAddr,
|
||||
) -> HttpResult
|
||||
{
|
||||
// See if we matched a :user parameter
|
||||
// If so, it must be valid UTF-8, or we return NOT_FOUND.
|
||||
let user_param = match route.params[0].as_ref() {
|
||||
Some(p) => {
|
||||
match p.as_str() {
|
||||
Some(p) => Some(p),
|
||||
None => {
|
||||
debug!("handle: invalid utf-8 in :user part of path");
|
||||
return self.error(StatusCode::NOT_FOUND).await;
|
||||
},
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Do authentication if needed.
|
||||
let auth_hdr = req.headers().typed_get::<Authorization<Basic>>();
|
||||
let do_auth = match location.auth {
|
||||
Some(Auth::True) => true,
|
||||
Some(Auth::Write) => !DavMethodSet::WEBDAV_RO.contains(method) || auth_hdr.is_some(),
|
||||
Some(Auth::False) => false,
|
||||
Some(Auth::Opportunistic) | None => auth_hdr.is_some(),
|
||||
};
|
||||
let auth_user = if do_auth {
|
||||
let user = match self.auth.auth(&req, location, remote_ip).await {
|
||||
Ok(user) => user,
|
||||
Err(status) => return self.auth_error(status, location).await,
|
||||
};
|
||||
// if there was a :user in the route, return error if it does not match.
|
||||
if user_param.map(|u| u != &user).unwrap_or(false) {
|
||||
debug!("handle: auth user and :user mismatch");
|
||||
return self.auth_error(StatusCode::UNAUTHORIZED, location).await;
|
||||
}
|
||||
Some(user)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Now see if we want to do a account lookup, for uid/gid/homedir.
|
||||
let pwd = match self.acct(location, auth_user.as_ref(), user_param).await {
|
||||
Ok(pwd) => pwd,
|
||||
Err(status) => return self.auth_error(status, location).await,
|
||||
};
|
||||
|
||||
// Expand "~" in the directory.
|
||||
let dir = match expand_directory(location.directory.as_str(), pwd.as_ref()) {
|
||||
Ok(d) => d,
|
||||
Err(_) => return self.error(StatusCode::NOT_FOUND).await,
|
||||
};
|
||||
|
||||
// If :path matched, we can calculate the prefix.
|
||||
// If it didn't, the entire path _is_ the prefix.
|
||||
let prefix = match route.params[1].as_ref() {
|
||||
Some(p) => {
|
||||
let mut start = p.start();
|
||||
if start > 0 {
|
||||
start -= 1;
|
||||
}
|
||||
&path[..start]
|
||||
},
|
||||
None => path,
|
||||
};
|
||||
let prefix = match std::str::from_utf8(prefix) {
|
||||
Ok(p) => p.to_string(),
|
||||
Err(_) => {
|
||||
debug!("handle: prefix is non-UTF8");
|
||||
return self.error(StatusCode::NOT_FOUND).await;
|
||||
},
|
||||
};
|
||||
|
||||
// Get User-Agent for user-agent specific modes.
|
||||
let user_agent = req
|
||||
.headers()
|
||||
.get("user-agent")
|
||||
.and_then(|s| s.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
// Case insensitivity wanted?
|
||||
let case_insensitive = match location.case_insensitive {
|
||||
Some(CaseInsensitive::True) => true,
|
||||
Some(CaseInsensitive::Ms) => user_agent.contains("Microsoft"),
|
||||
Some(CaseInsensitive::False) | None => false,
|
||||
};
|
||||
|
||||
// macOS optimizations?
|
||||
let macos = user_agent.contains("WebDAVFS/") && user_agent.contains("Darwin");
|
||||
|
||||
// Get the filesystem.
|
||||
let auth_ugid = if location.setuid {
|
||||
pwd.as_ref().map(|p| (p.uid, p.gid, p.groups.as_slice()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let fs = match location.handler {
|
||||
Handler::Virtroot => {
|
||||
let auth_user = auth_user.as_ref().map(String::to_owned);
|
||||
RootFs::new(dir, auth_user, auth_ugid) as Box<dyn DavFileSystem>
|
||||
},
|
||||
Handler::Filesystem => {
|
||||
UserFs::new(dir, auth_ugid, true, case_insensitive, macos) as Box<dyn DavFileSystem>
|
||||
},
|
||||
};
|
||||
|
||||
// Build a handler.
|
||||
let methods = location
|
||||
.methods
|
||||
.unwrap_or(DavMethodSet::from_vec(vec!["GET", "HEAD"]).unwrap());
|
||||
let hide_symlinks = location.hide_symlinks.clone().unwrap_or(true);
|
||||
|
||||
let mut config = DavConfig::new()
|
||||
.filesystem(fs)
|
||||
.strip_prefix(prefix)
|
||||
.methods(methods)
|
||||
.hide_symlinks(hide_symlinks)
|
||||
.autoindex(location.autoindex);
|
||||
if let Some(auth_user) = auth_user {
|
||||
config = config.principal(auth_user);
|
||||
}
|
||||
if let Some(indexfile) = location.indexfile.clone() {
|
||||
config = config.indexfile(indexfile);
|
||||
}
|
||||
|
||||
// All set.
|
||||
self.run_davhandler(config, req).await
|
||||
}
|
||||
|
||||
async fn build_error(&self, code: StatusCode, location: Option<&Location>) -> HttpResult {
|
||||
let msg = format!(
|
||||
"<error>{} {}</error>\n",
|
||||
code.as_u16(),
|
||||
code.canonical_reason().unwrap_or("")
|
||||
);
|
||||
let mut response = self
|
||||
.response_builder()
|
||||
.status(code)
|
||||
.header("Content-Type", "text/xml");
|
||||
if code == StatusCode::UNAUTHORIZED {
|
||||
let realm = location.and_then(|location| location.accounts.realm.as_ref());
|
||||
let realm = realm.or(self.config.accounts.realm.as_ref());
|
||||
let realm = realm.map(|s| s.as_str()).unwrap_or("Webdav Server");
|
||||
response = response.header("WWW-Authenticate", format!("Basic realm=\"{}\"", realm).as_str());
|
||||
}
|
||||
Ok(response.body(msg.into()).unwrap())
|
||||
}
|
||||
|
||||
async fn auth_error(&self, code: StatusCode, location: &Location) -> HttpResult {
|
||||
self.build_error(code, Some(location)).await
|
||||
}
|
||||
|
||||
async fn error(&self, code: StatusCode) -> HttpResult {
|
||||
self.build_error(code, None).await
|
||||
}
|
||||
|
||||
// Call the davhandler, then add headers to the response.
|
||||
async fn run_davhandler(&self, config: DavConfig, req: HttpRequest) -> HttpResult {
|
||||
let resp = self.dh.handle_with(config, req).await;
|
||||
let (mut parts, body) = resp.into_parts();
|
||||
self.set_headers(&mut parts.headers);
|
||||
Ok(http::Response::from_parts(parts, body))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// command line option processing.
|
||||
let matches = clap_app!(webdav_server =>
|
||||
(version: "0.3")
|
||||
(@arg CFG: -c --config +takes_value "configuration file (/etc/webdav-server.toml)")
|
||||
(@arg PORT: -p --port +takes_value "listen to this port on localhost only")
|
||||
(@arg DBG: -D --debug "enable debug level logging")
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
if matches.is_present("DBG") {
|
||||
use env_logger::Env;
|
||||
let level = "webdav_server=debug,webdav_handler=debug";
|
||||
env_logger::Builder::from_env(Env::default().default_filter_or(level)).init();
|
||||
} else {
|
||||
env_logger::init();
|
||||
}
|
||||
|
||||
let port = matches.value_of("PORT");
|
||||
let cfg = matches.value_of("CFG").unwrap_or("/etc/webdav-server.toml");
|
||||
|
||||
// read config.
|
||||
let mut config = match config::read(cfg.clone()) {
|
||||
Err(e) => {
|
||||
eprintln!("{}: {}: {}", PROGNAME, cfg, e);
|
||||
exit(1);
|
||||
},
|
||||
Ok(c) => c,
|
||||
};
|
||||
config::check(cfg.clone(), &config);
|
||||
|
||||
// build routes.
|
||||
if let Err(e) = config::build_routes(cfg.clone(), &mut config) {
|
||||
eprintln!("{}: {}: {}", PROGNAME, cfg, e);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if let Some(port) = port {
|
||||
let localhosts = vec![
|
||||
("127.0.0.1:".to_string() + port).parse::<SocketAddr>().unwrap(),
|
||||
("[::]:".to_string() + port).parse::<SocketAddr>().unwrap(),
|
||||
];
|
||||
config.server.listen = config::OneOrManyAddr::Many(localhosts);
|
||||
}
|
||||
let config = Arc::new(config);
|
||||
|
||||
// set cache timeouts.
|
||||
if let Some(timeout) = config.unix.cache_timeout {
|
||||
cache::cached::set_pwcache_timeout(timeout);
|
||||
}
|
||||
|
||||
// resolve addresses.
|
||||
let addrs = config.server.listen.clone().to_socket_addrs().unwrap_or_else(|e| {
|
||||
eprintln!("{}: {}: [server] listen: {:?}", PROGNAME, cfg, e);
|
||||
exit(1);
|
||||
});
|
||||
let tls_addrs = config.server.tls_listen.clone().to_socket_addrs().unwrap_or_else(|e| {
|
||||
eprintln!("{}: {}: [server] listen: {:?}", PROGNAME, cfg, e);
|
||||
exit(1);
|
||||
});
|
||||
|
||||
// initialize auth early.
|
||||
let auth = auth::Auth::new(config.clone())?;
|
||||
|
||||
// start tokio runtime and initialize the rest from within the runtime.
|
||||
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()?;
|
||||
|
||||
rt.block_on(async move {
|
||||
// build servers (one for each listen address).
|
||||
let dav_server = Server::new(config.clone(), auth);
|
||||
let mut servers = Vec::new();
|
||||
let mut tls_servers = Vec::new();
|
||||
|
||||
// Plaintext servers.
|
||||
for sockaddr in addrs {
|
||||
let listener = match make_listener(sockaddr) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
eprintln!("{}: listener on {:?}: {}", PROGNAME, &sockaddr, e);
|
||||
exit(1);
|
||||
},
|
||||
};
|
||||
let dav_server = dav_server.clone();
|
||||
let make_service = make_service_fn(move |socket: &AddrStream| {
|
||||
let dav_server = dav_server.clone();
|
||||
let remote_addr = socket.remote_addr();
|
||||
async move {
|
||||
let func = move |req| {
|
||||
let dav_server = dav_server.clone();
|
||||
async move { dav_server.route(req, remote_addr).await }
|
||||
};
|
||||
Ok::<_, hyper::Error>(service_fn(func))
|
||||
}
|
||||
});
|
||||
let incoming = AddrIncoming::from_listener(listener)?;
|
||||
let server = hyper::Server::builder(incoming);
|
||||
println!("Listening on http://{:?}", sockaddr);
|
||||
|
||||
servers.push(async move {
|
||||
if let Err(e) = server.serve(make_service).await {
|
||||
eprintln!("{}: server error: {}", PROGNAME, e);
|
||||
exit(1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// TLS servers.
|
||||
if tls_addrs.len() > 0 {
|
||||
let tls_acceptor = tls_acceptor(&config.server)?;
|
||||
|
||||
for sockaddr in tls_addrs {
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let listener = make_listener(sockaddr).unwrap_or_else(|e| {
|
||||
eprintln!("{}: listener on {:?}: {}", PROGNAME, &sockaddr, e);
|
||||
exit(1);
|
||||
});
|
||||
let dav_server = dav_server.clone();
|
||||
let make_service = make_service_fn(move |stream: &TlsStream<AddrStream>| {
|
||||
let dav_server = dav_server.clone();
|
||||
let remote_addr = stream.get_ref().0.remote_addr();
|
||||
async move {
|
||||
let func = move |req| {
|
||||
let dav_server = dav_server.clone();
|
||||
async move { dav_server.route(req, remote_addr).await }
|
||||
};
|
||||
Ok::<_, hyper::Error>(service_fn(func))
|
||||
}
|
||||
});
|
||||
|
||||
// Since the server can exit when there's an error on the TlsStream,
|
||||
// we run it in a loop. Every time the loop is entered we dup() the
|
||||
// listening fd and create a new TcpListener. This way, we should
|
||||
// not lose any pending connections during a restart.
|
||||
let master_listen_fd = listener.as_raw_fd();
|
||||
std::mem::forget(listener);
|
||||
|
||||
println!("Listening on https://{:?}", sockaddr);
|
||||
tls_servers.push(async move {
|
||||
loop {
|
||||
// reuse the incoming socket after the server exits.
|
||||
let listen_fd = match nix::unistd::dup(master_listen_fd) {
|
||||
Ok(fd) => fd,
|
||||
Err(e) => {
|
||||
eprintln!("{}: server error: dup: {}", PROGNAME, e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
// SAFETY: listen_fd is unique (we just dup'ed it).
|
||||
let std_listen = unsafe { std::net::TcpListener::from_raw_fd(listen_fd) };
|
||||
let listener = match tokio::net::TcpListener::from_std(std_listen) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
eprintln!("{}: server error: new TcpListener: {}", PROGNAME, e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let a_incoming = match AddrIncoming::from_listener(listener) {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
eprintln!("{}: server error: new AddrIncoming: {}", PROGNAME, e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let incoming = TlsListener::new(tls_acceptor.clone(), a_incoming);
|
||||
let server = hyper::Server::builder(incoming);
|
||||
if let Err(e) = server.serve(make_service.clone()).await {
|
||||
eprintln!("{}: server error: {} (retrying)", PROGNAME, e);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// drop privs.
|
||||
match (&config.server.uid, &config.server.gid) {
|
||||
(&Some(uid), &Some(gid)) => {
|
||||
if !suid::have_suid_privs() {
|
||||
eprintln!(
|
||||
"{}: insufficent priviliges to switch uid/gid (not root).",
|
||||
PROGNAME
|
||||
);
|
||||
exit(1);
|
||||
}
|
||||
let keep_privs = config.location.iter().any(|l| l.setuid);
|
||||
proc_switch_ugid(uid, gid, keep_privs);
|
||||
},
|
||||
_ => {},
|
||||
}
|
||||
|
||||
// spawn all servers, and wait for them to finish.
|
||||
let mut tasks = Vec::new();
|
||||
for server in servers.drain(..) {
|
||||
tasks.push(tokio::spawn(server));
|
||||
}
|
||||
for server in tls_servers.drain(..) {
|
||||
tasks.push(tokio::spawn(server));
|
||||
}
|
||||
for task in tasks.drain(..) {
|
||||
let _ = task.await;
|
||||
}
|
||||
|
||||
Ok::<_, Box<dyn std::error::Error>>(())
|
||||
})
|
||||
}
|
||||
|
||||
// Clones a http request with an empty body.
|
||||
fn clone_httpreq(req: &HttpRequest) -> HttpRequest {
|
||||
let mut builder = http::Request::builder()
|
||||
.method(req.method().clone())
|
||||
.uri(req.uri().clone())
|
||||
.version(req.version().clone());
|
||||
for (name, value) in req.headers().iter() {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
builder.body(hyper::Body::empty()).unwrap()
|
||||
}
|
||||
|
||||
fn expand_directory(dir: &str, pwd: Option<&Arc<unixuser::User>>) -> Result<String, StatusCode> {
|
||||
// If it doesn't start with "~", skip.
|
||||
if !dir.starts_with("~") {
|
||||
return Ok(dir.to_string());
|
||||
}
|
||||
// ~whatever doesn't work.
|
||||
if dir.len() > 1 && !dir.starts_with("~/") {
|
||||
debug!("expand_directory: rejecting {}", dir);
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
// must have a directory, and that dir must be UTF-8.
|
||||
let pwd = match pwd {
|
||||
Some(pwd) => pwd,
|
||||
None => {
|
||||
debug!("expand_directory: cannot expand {}: no account", dir);
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
},
|
||||
};
|
||||
let homedir = pwd.dir.to_str().ok_or(StatusCode::NOT_FOUND)?;
|
||||
Ok(format!("{}/{}", homedir, &dir[1..]))
|
||||
}
|
||||
|
||||
// Make a new TcpListener, and if it's a V6 listener, set the
|
||||
// V6_V6ONLY socket option on it.
|
||||
fn make_listener(addr: SocketAddr) -> io::Result<tokio::net::TcpListener> {
|
||||
use socket2::{Domain, SockAddr, Socket, Type, Protocol};
|
||||
let s = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
|
||||
if addr.is_ipv6() {
|
||||
s.set_only_v6(true)?;
|
||||
}
|
||||
s.set_nonblocking(true)?;
|
||||
s.set_nodelay(true)?;
|
||||
s.set_reuse_address(true)?;
|
||||
let addr: SockAddr = addr.into();
|
||||
s.bind(&addr)?;
|
||||
s.listen(128)?;
|
||||
let listener: std::net::TcpListener = s.into();
|
||||
tokio::net::TcpListener::from_std(listener)
|
||||
}
|
||||
112
src/rootfs.rs
Normal file
112
src/rootfs.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
//
|
||||
// Virtual Root filesystem for PROPFIND.
|
||||
//
|
||||
// Shows "/" and "/user".
|
||||
//
|
||||
use std;
|
||||
use std::path::Path;
|
||||
|
||||
use futures::future::{self, FutureExt};
|
||||
use webdav_handler::davpath::DavPath;
|
||||
use webdav_handler::fs::*;
|
||||
|
||||
use crate::userfs::UserFs;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RootFs {
|
||||
user: String,
|
||||
fs: UserFs,
|
||||
}
|
||||
|
||||
impl RootFs {
|
||||
pub fn new<P>(dir: P, user: Option<String>, creds: Option<(u32, u32, &[u32])>) -> Box<RootFs>
|
||||
where P: AsRef<Path> + Clone {
|
||||
Box::new(RootFs {
|
||||
user: user.unwrap_or("".to_string()),
|
||||
fs: *UserFs::new(dir, creds, false, false, true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl DavFileSystem for RootFs {
|
||||
// Only allow "/" or "/user", for both return the metadata of the UserFs root.
|
||||
fn metadata<'a>(&'a self, path: &'a DavPath) -> FsFuture<Box<dyn DavMetaData>> {
|
||||
async move {
|
||||
let b = path.as_bytes();
|
||||
if b != b"/" && &b[1..] != self.user.as_bytes() {
|
||||
return Err(FsError::NotFound);
|
||||
}
|
||||
let path = DavPath::new("/").unwrap();
|
||||
self.fs.metadata(&path).await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
// Only return one entry: "user".
|
||||
fn read_dir<'a>(
|
||||
&'a self,
|
||||
path: &'a DavPath,
|
||||
_meta: ReadDirMeta,
|
||||
) -> FsFuture<FsStream<Box<dyn DavDirEntry>>>
|
||||
{
|
||||
Box::pin(async move {
|
||||
let mut v = Vec::new();
|
||||
if self.user != "" {
|
||||
v.push(RootFsDirEntry {
|
||||
name: self.user.clone(),
|
||||
meta: self.fs.metadata(path).await,
|
||||
});
|
||||
}
|
||||
let strm = futures::stream::iter(RootFsReadDir {
|
||||
iterator: v.into_iter(),
|
||||
});
|
||||
Ok(Box::pin(strm) as FsStream<Box<dyn DavDirEntry>>)
|
||||
})
|
||||
}
|
||||
|
||||
// cannot open any files.
|
||||
fn open(&self, _path: &DavPath, _options: OpenOptions) -> FsFuture<Box<dyn DavFile>> {
|
||||
Box::pin(future::ready(Err(FsError::NotImplemented)))
|
||||
}
|
||||
|
||||
// forward quota.
|
||||
fn get_quota(&self) -> FsFuture<(u64, Option<u64>)> {
|
||||
self.fs.get_quota()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RootFsReadDir {
|
||||
iterator: std::vec::IntoIter<RootFsDirEntry>,
|
||||
}
|
||||
|
||||
impl Iterator for RootFsReadDir {
|
||||
type Item = Box<dyn DavDirEntry>;
|
||||
|
||||
fn next(&mut self) -> Option<Box<dyn DavDirEntry>> {
|
||||
match self.iterator.next() {
|
||||
None => return None,
|
||||
Some(entry) => Some(Box::new(entry)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RootFsDirEntry {
|
||||
meta: FsResult<Box<dyn DavMetaData>>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl DavDirEntry for RootFsDirEntry {
|
||||
fn metadata(&self) -> FsFuture<Box<dyn DavMetaData>> {
|
||||
Box::pin(future::ready(self.meta.clone()))
|
||||
}
|
||||
|
||||
fn name(&self) -> Vec<u8> {
|
||||
self.name.as_bytes().to_vec()
|
||||
}
|
||||
|
||||
fn is_dir(&self) -> FsFuture<bool> {
|
||||
Box::pin(future::ready(Ok(true)))
|
||||
}
|
||||
}
|
||||
262
src/router.rs
Normal file
262
src/router.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
//!
|
||||
//! Simple and stupid HTTP router.
|
||||
//!
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use regex::bytes::{Match, Regex, RegexSet};
|
||||
use webdav_handler::{DavMethod, DavMethodSet};
|
||||
|
||||
// internal representation of a route.
|
||||
#[derive(Debug)]
|
||||
struct Route<T: Debug> {
|
||||
regex: Regex,
|
||||
methods: Option<DavMethodSet>,
|
||||
data: T,
|
||||
}
|
||||
|
||||
/// A matched route.
|
||||
#[derive(Debug)]
|
||||
pub struct MatchedRoute<'t, 'p, T: Debug> {
|
||||
pub methods: Option<DavMethodSet>,
|
||||
pub params: Vec<Option<Param<'p>>>,
|
||||
pub data: &'t T,
|
||||
}
|
||||
|
||||
/// A parameter on a matched route.
|
||||
pub struct Param<'p>(Match<'p>);
|
||||
|
||||
impl Debug for Param<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
f.debug_struct("Param")
|
||||
.field("start", &self.0.start())
|
||||
.field("end", &self.0.end())
|
||||
.field("as_str", &std::str::from_utf8(self.0.as_bytes()).ok())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'p> Param<'p> {
|
||||
/// Returns the starting byte offset of the match in the path.
|
||||
#[inline]
|
||||
pub fn start(&self) -> usize {
|
||||
self.0.start()
|
||||
}
|
||||
|
||||
/// Returns the ending byte offset of the match in the path.
|
||||
#[inline]
|
||||
pub fn end(&self) -> usize {
|
||||
self.0.end()
|
||||
}
|
||||
|
||||
/// Returns the matched part of the path.
|
||||
#[inline]
|
||||
pub fn as_bytes(&self) -> &'p [u8] {
|
||||
self.0.as_bytes()
|
||||
}
|
||||
|
||||
/// Returns the matched part of the path as a &str, if it is valid utf-8.
|
||||
#[inline]
|
||||
pub fn as_str(&self) -> Option<&'p str> {
|
||||
std::str::from_utf8(self.0.as_bytes()).ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Builder<T: Debug> {
|
||||
routes: Vec<Route<T>>,
|
||||
}
|
||||
|
||||
impl<T: Debug> Builder<T> {
|
||||
/// Add a route.
|
||||
///
|
||||
/// Routes are matched in the order they were added.
|
||||
///
|
||||
/// If a route starts with '^', it's assumed that it is a regular
|
||||
/// expression. Parameters are included as "named capture groups".
|
||||
///
|
||||
/// Otherwise, it's a route-expression, with just the normal :params
|
||||
/// and *splat param, and parts between parentheses are optional.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// - /api/get/:id
|
||||
/// - /files/*path
|
||||
/// - /users(/)
|
||||
/// - /users(/*path)
|
||||
///
|
||||
pub fn add(
|
||||
&mut self,
|
||||
route: impl AsRef<str>,
|
||||
methods: Option<DavMethodSet>,
|
||||
data: T,
|
||||
) -> Result<&mut Self, regex::Error>
|
||||
{
|
||||
let route = route.as_ref();
|
||||
// Might be a regexp
|
||||
if route.starts_with("^") {
|
||||
return self.add_re(route, methods, data);
|
||||
}
|
||||
// Ignore it if it does not start with /
|
||||
if !route.starts_with("/") {
|
||||
return Ok(self);
|
||||
}
|
||||
|
||||
// First, replace special characters "()*" with unicode chars
|
||||
// from the private-use area, so that we can then regex-escape
|
||||
// the entire string.
|
||||
let re_route = route
|
||||
.chars()
|
||||
.map(|c| {
|
||||
match c {
|
||||
'*' => '\u{e001}',
|
||||
'(' => '\u{e002}',
|
||||
')' => '\u{e003}',
|
||||
'\u{e001}' => ' ',
|
||||
'\u{e002}' => ' ',
|
||||
'\u{e003}' => ' ',
|
||||
c => c,
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
let re_route = regex::escape(&re_route);
|
||||
|
||||
// Translate route expression into regexp.
|
||||
// We do a simple transformation:
|
||||
// :ident -> (?P<ident>[^/]*)
|
||||
// *ident -> (?P<ident>.*)
|
||||
// (text) -> (?:text|)
|
||||
lazy_static! {
|
||||
static ref COLON: Regex = Regex::new(":([a-zA-Z0-9]+)").unwrap();
|
||||
static ref SPLAT: Regex = Regex::new("\u{e001}([a-zA-Z0-9]+)").unwrap();
|
||||
static ref MAYBE: Regex = Regex::new("\u{e002}([^\u{e002}]*)\u{e003}").unwrap();
|
||||
};
|
||||
let mut re_route = re_route.into_bytes();
|
||||
re_route = COLON.replace_all(&re_route, &b"(?P<$1>[^/]*)"[..]).to_vec();
|
||||
re_route = SPLAT.replace_all(&re_route, &b"(?P<$1>.*)"[..]).to_vec();
|
||||
re_route = MAYBE.replace_all(&re_route, &b"($1)?"[..]).to_vec();
|
||||
|
||||
// finalize regex.
|
||||
let re_route = "^".to_string() + &String::from_utf8(re_route).unwrap() + "$";
|
||||
|
||||
self.add_re(&re_route, methods, data)
|
||||
}
|
||||
|
||||
// add route as regular expression.
|
||||
fn add_re(&mut self, s: &str, methods: Option<DavMethodSet>, data: T) -> Result<&mut Self, regex::Error> {
|
||||
// Set flags: enable ". matches everything", disable strict unicode.
|
||||
// We known 's' starts with "^", add it after that.
|
||||
let s2 = format!("^(?s){}", &s[1..]);
|
||||
let regex = Regex::new(&s2)?;
|
||||
self.routes.push(Route { regex, methods, data });
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Combine all the routes and compile them into an internal RegexSet.
|
||||
pub fn build(&mut self) -> Router<T> {
|
||||
let set = RegexSet::new(self.routes.iter().map(|r| r.regex.as_str())).unwrap();
|
||||
Router {
|
||||
routes: std::mem::replace(&mut self.routes, Vec::new()),
|
||||
set,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dead simple HTTP router.
|
||||
#[derive(Debug)]
|
||||
pub struct Router<T: Debug> {
|
||||
set: RegexSet,
|
||||
routes: Vec<Route<T>>,
|
||||
}
|
||||
|
||||
impl<T: Debug> Default for Router<T> {
|
||||
fn default() -> Router<T> {
|
||||
Router {
|
||||
set: RegexSet::new(&[] as &[&str]).unwrap(),
|
||||
routes: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug> Router<T> {
|
||||
/// Return a builder.
|
||||
pub fn builder() -> Builder<T> {
|
||||
Builder { routes: Vec::new() }
|
||||
}
|
||||
|
||||
/// See if the path matches a route in the set.
|
||||
///
|
||||
/// The names of the parameters you want to be returned need to be passed in as an array.
|
||||
pub fn matches<'a>(
|
||||
&self,
|
||||
path: &'a [u8],
|
||||
method: DavMethod,
|
||||
param_names: &[&str],
|
||||
) -> Vec<MatchedRoute<'_, 'a, T>>
|
||||
{
|
||||
let mut matched = Vec::new();
|
||||
for idx in self.set.matches(path) {
|
||||
let route = &self.routes[idx];
|
||||
if route.methods.map(|m| m.contains(method)).unwrap_or(true) {
|
||||
let mut params = Vec::new();
|
||||
if let Some(caps) = route.regex.captures(path) {
|
||||
for name in param_names {
|
||||
params.push(caps.name(name).map(|p| Param(p)));
|
||||
}
|
||||
} else {
|
||||
for _ in param_names {
|
||||
params.push(None);
|
||||
}
|
||||
}
|
||||
matched.push(MatchedRoute {
|
||||
methods: route.methods,
|
||||
params,
|
||||
data: &route.data,
|
||||
});
|
||||
}
|
||||
}
|
||||
matched
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use webdav_handler::DavMethod;
|
||||
|
||||
fn test_match(rtr: &Router<usize>, p: &[u8], user: &str, path: &str) {
|
||||
let x = rtr.matches(p, DavMethod::Get, &["user", "path"]);
|
||||
assert!(x.len() > 0);
|
||||
let x = &x[0];
|
||||
if user != "" {
|
||||
assert!(x.params[0]
|
||||
.as_ref()
|
||||
.map(|b| b.as_bytes() == user.as_bytes())
|
||||
.unwrap_or(false));
|
||||
}
|
||||
if path != "" {
|
||||
assert!(x.params[1]
|
||||
.as_ref()
|
||||
.map(|b| b.as_bytes() == path.as_bytes())
|
||||
.unwrap_or(false));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_router() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let rtr = Router::<usize>::builder()
|
||||
.add("/", None, 1)?
|
||||
.add("/users(/:user)", None, 2)?
|
||||
.add("/files/*path", None, 3)?
|
||||
.add("/files(/*path)", None, 4)?
|
||||
.build();
|
||||
|
||||
test_match(&rtr, b"/", "", "");
|
||||
test_match(&rtr, b"/users", "", "");
|
||||
test_match(&rtr, b"/users/", "", "");
|
||||
test_match(&rtr, b"/users/mike", "mike", "");
|
||||
test_match(&rtr, b"/files/foo/bar", "", "foo/bar");
|
||||
test_match(&rtr, b"/files", "", "");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
304
src/suid.rs
Normal file
304
src/suid.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
use std::io;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
static THREAD_SWITCH_UGID_USED: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
#[cfg(all(target_os = "linux"))]
|
||||
mod setuid {
|
||||
// On x86, the default SYS_setresuid is 16 bits. We need to
|
||||
// import the 32-bit variant.
|
||||
#[cfg(target_arch = "x86")]
|
||||
mod uid32 {
|
||||
pub use libc::SYS_getgroups32 as SYS_getgroups;
|
||||
pub use libc::SYS_setgroups32 as SYS_setgroups;
|
||||
pub use libc::SYS_setresgid32 as SYS_setresgid;
|
||||
pub use libc::SYS_setresuid32 as SYS_setresuid;
|
||||
}
|
||||
#[cfg(not(target_arch = "x86"))]
|
||||
mod uid32 {
|
||||
pub use libc::{SYS_getgroups, SYS_setgroups, SYS_setresgid, SYS_setresuid};
|
||||
}
|
||||
use self::uid32::*;
|
||||
use std::cell::RefCell;
|
||||
use std::convert::TryInto;
|
||||
use std::io;
|
||||
use std::sync::atomic::Ordering;
|
||||
const ID_NONE: libc::uid_t = 0xffffffff;
|
||||
|
||||
// current credentials of this thread.
|
||||
struct UgidState {
|
||||
ruid: u32,
|
||||
euid: u32,
|
||||
rgid: u32,
|
||||
egid: u32,
|
||||
groups: Vec<u32>,
|
||||
}
|
||||
|
||||
impl UgidState {
|
||||
fn new() -> UgidState {
|
||||
super::THREAD_SWITCH_UGID_USED.store(true, Ordering::Release);
|
||||
UgidState {
|
||||
ruid: unsafe { libc::getuid() } as u32,
|
||||
euid: unsafe { libc::geteuid() } as u32,
|
||||
rgid: unsafe { libc::getgid() } as u32,
|
||||
egid: unsafe { libc::getegid() } as u32,
|
||||
groups: getgroups().expect("UgidState::new"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn getgroups() -> io::Result<Vec<u32>> {
|
||||
// get number of groups.
|
||||
let size = unsafe {
|
||||
libc::syscall(
|
||||
SYS_getgroups,
|
||||
0 as libc::c_int,
|
||||
std::ptr::null_mut::<libc::gid_t>(),
|
||||
)
|
||||
};
|
||||
if size < 0 {
|
||||
return Err(oserr(size, "getgroups(0, NULL)"));
|
||||
}
|
||||
|
||||
// get groups.
|
||||
let mut groups = Vec::<u32>::with_capacity(size as usize);
|
||||
groups.resize(size as usize, 0);
|
||||
let res = unsafe { libc::syscall(SYS_getgroups, size as libc::c_int, groups.as_mut_ptr() as *mut _) };
|
||||
|
||||
// sanity check.
|
||||
if res != size {
|
||||
if res < 0 {
|
||||
return Err(oserr(res, format!("getgroups({}, buffer)", size)));
|
||||
}
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("getgroups({}, buffer): returned {}", size, res),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
fn oserr(code: libc::c_long, msg: impl AsRef<str>) -> io::Error {
|
||||
let msg = msg.as_ref();
|
||||
let err = io::Error::from_raw_os_error(code.try_into().unwrap());
|
||||
io::Error::new(err.kind(), format!("{}: {}", msg, err))
|
||||
}
|
||||
|
||||
// thread-local seteuid.
|
||||
fn seteuid(uid: u32) -> io::Result<()> {
|
||||
let res = unsafe { libc::syscall(SYS_setresuid, ID_NONE, uid, ID_NONE) };
|
||||
if res < 0 {
|
||||
return Err(oserr(res, format!("seteuid({})", uid)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// thread-local setegid.
|
||||
fn setegid(gid: u32) -> io::Result<()> {
|
||||
let res = unsafe { libc::syscall(SYS_setresgid, ID_NONE, gid, ID_NONE) };
|
||||
if res < 0 {
|
||||
return Err(oserr(res, format!("setegid({})", gid)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// thread-local setgroups.
|
||||
fn setgroups(gids: &[u32]) -> io::Result<()> {
|
||||
let size = gids.len() as libc::c_int;
|
||||
let res = unsafe { libc::syscall(SYS_setgroups, size, gids.as_ptr() as *const libc::gid_t) };
|
||||
if res < 0 {
|
||||
return Err(oserr(res, format!("setgroups({}, {:?}", size, gids)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// credential state is thread-local.
|
||||
thread_local!(static CURRENT_UGID: RefCell<UgidState> = RefCell::new(UgidState::new()));
|
||||
|
||||
/// Switch thread credentials.
|
||||
pub(super) fn thread_switch_ugid(newuid: u32, newgid: u32, newgroups: &[u32]) -> (u32, u32, Vec<u32>) {
|
||||
CURRENT_UGID.with(|current_ugid| {
|
||||
let mut cur = current_ugid.borrow_mut();
|
||||
let (olduid, oldgid, oldgroups) = (cur.euid, cur.egid, cur.groups.clone());
|
||||
let groups_changed = newgroups != cur.groups.as_slice();
|
||||
|
||||
// Check if anything changed.
|
||||
if newuid != cur.euid || newgid != cur.egid || groups_changed {
|
||||
// See if we have to switch to root privs first.
|
||||
if cur.euid != 0 && (newuid != cur.ruid || newgid != cur.rgid || groups_changed) {
|
||||
// Must first switch to root.
|
||||
if let Err(e) = seteuid(0) {
|
||||
panic!("{}", e);
|
||||
}
|
||||
cur.euid = 0;
|
||||
}
|
||||
|
||||
if newgid != cur.egid {
|
||||
// Change gid.
|
||||
if let Err(e) = setegid(newgid) {
|
||||
panic!("{}", e);
|
||||
}
|
||||
cur.egid = newgid;
|
||||
}
|
||||
if groups_changed {
|
||||
// Change groups.
|
||||
if let Err(e) = setgroups(newgroups) {
|
||||
panic!("{}", e);
|
||||
}
|
||||
cur.groups.truncate(0);
|
||||
cur.groups.extend_from_slice(newgroups);
|
||||
}
|
||||
if newuid != cur.euid {
|
||||
// Change uid.
|
||||
if let Err(e) = seteuid(newuid) {
|
||||
panic!("{}", e);
|
||||
}
|
||||
cur.euid = newuid;
|
||||
}
|
||||
}
|
||||
(olduid, oldgid, oldgroups)
|
||||
})
|
||||
}
|
||||
|
||||
// Yep..
|
||||
pub fn has_thread_switch_ugid() -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
mod setuid {
|
||||
// Not implemented, as it looks like only Linux has support for
|
||||
// per-thread uid/gid switching.
|
||||
//
|
||||
// DO NOT implement this through libc::setuid, as that will
|
||||
// switch the uids of all threads.
|
||||
//
|
||||
/// Switch thread credentials. Not implemented!
|
||||
pub(super) fn thread_switch_ugid(_newuid: u32, _newgid: u32, _newgroups: &[u32]) -> (u32, u32, Vec<u32>) {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
// Nope.
|
||||
pub fn has_thread_switch_ugid() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub use self::setuid::has_thread_switch_ugid;
|
||||
use self::setuid::thread_switch_ugid;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct UgidCreds {
|
||||
pub uid: u32,
|
||||
pub gid: u32,
|
||||
pub groups: Vec<u32>,
|
||||
}
|
||||
|
||||
pub struct UgidSwitch {
|
||||
target_creds: Option<UgidCreds>,
|
||||
}
|
||||
|
||||
pub struct UgidSwitchGuard {
|
||||
base_creds: Option<UgidCreds>,
|
||||
}
|
||||
|
||||
impl UgidSwitch {
|
||||
pub fn new(creds: Option<(u32, u32, &[u32])>) -> UgidSwitch {
|
||||
let target_creds = match creds {
|
||||
Some((uid, gid, groups)) => {
|
||||
Some(UgidCreds {
|
||||
uid,
|
||||
gid,
|
||||
groups: groups.into(),
|
||||
})
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
UgidSwitch { target_creds }
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn run<F, R>(&self, func: F) -> R
|
||||
where F: FnOnce() -> R {
|
||||
let _guard = self.guard();
|
||||
func()
|
||||
}
|
||||
|
||||
pub fn guard(&self) -> UgidSwitchGuard {
|
||||
match &self.target_creds {
|
||||
&None => UgidSwitchGuard { base_creds: None },
|
||||
&Some(ref creds) => {
|
||||
let (uid, gid, groups) = thread_switch_ugid(creds.uid, creds.gid, &creds.groups);
|
||||
UgidSwitchGuard {
|
||||
base_creds: Some(UgidCreds { uid, gid, groups }),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UgidSwitchGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Some(ref creds) = self.base_creds {
|
||||
thread_switch_ugid(creds.uid, creds.gid, &creds.groups);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Switch process credentials. Keeps the saved-uid as root, so that
|
||||
/// we can switch to other ids later on.
|
||||
pub fn proc_switch_ugid(uid: u32, gid: u32, keep_privs: bool) {
|
||||
if THREAD_SWITCH_UGID_USED.load(Ordering::Acquire) {
|
||||
panic!("proc_switch_ugid: called after thread_switch_ugid() has been used");
|
||||
}
|
||||
|
||||
fn last_os_error() -> io::Error {
|
||||
io::Error::last_os_error()
|
||||
}
|
||||
|
||||
unsafe {
|
||||
// first get full root privs (real, effective, and saved uids)
|
||||
if libc::setuid(0) != 0 {
|
||||
panic!("libc::setuid(0): {:?}", last_os_error());
|
||||
}
|
||||
|
||||
// set real uid, and keep effective uid at 0.
|
||||
#[cfg(not(any(target_os = "openbsd", target_os = "freebsd")))]
|
||||
if libc::setreuid(uid, 0) != 0 {
|
||||
panic!("libc::setreuid({}, 0): {:?}", uid, last_os_error());
|
||||
}
|
||||
#[cfg(any(target_os = "openbsd", target_os = "freebsd"))]
|
||||
if libc::setresuid(uid, 0, 0) != 0 {
|
||||
panic!("libc::setreuid({}, 0): {:?}", uid, last_os_error());
|
||||
}
|
||||
|
||||
// set group id.
|
||||
if libc::setgid(gid) != 0 {
|
||||
panic!("libc::setgid({}): {:?}", gid, last_os_error());
|
||||
}
|
||||
|
||||
// remove _all_ auxilary groups.
|
||||
if libc::setgroups(0, std::ptr::null::<libc::gid_t>()) != 0 {
|
||||
panic!("setgroups[]: {:?}", last_os_error());
|
||||
}
|
||||
|
||||
if keep_privs {
|
||||
// finally set effective uid. saved uid is still 0.
|
||||
if libc::seteuid(uid) != 0 {
|
||||
panic!("libc::seteuid({}): {:?}", uid, last_os_error());
|
||||
}
|
||||
} else {
|
||||
// drop all privs.
|
||||
if libc::setuid(uid) != 0 {
|
||||
panic!("libc::setuid({}): {:?}", uid, last_os_error());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Do we have sufficient privs to switch uids?
|
||||
pub fn have_suid_privs() -> bool {
|
||||
unsafe { libc::geteuid() == 0 }
|
||||
}
|
||||
58
src/tls.rs
Normal file
58
src/tls.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::fs::File;
|
||||
use std::io::{self, ErrorKind};
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use rustls_pemfile as pemfile;
|
||||
|
||||
use crate::config::Server;
|
||||
|
||||
pub fn tls_acceptor(cfg: &Server) -> io::Result<TlsAcceptor> {
|
||||
|
||||
// Private key.
|
||||
let pkey_fn = cfg.tls_key.as_ref().ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::NotFound, "config: server: tls_key not set")
|
||||
})?;
|
||||
let pkey_file = File::open(pkey_fn).map_err(|e| {
|
||||
io::Error::new(e.kind(), format!("{}: {}", pkey_fn, e))
|
||||
})?;
|
||||
let mut pkey_file = io::BufReader::new(pkey_file);
|
||||
let pkey = match pemfile::read_one(&mut pkey_file) {
|
||||
Ok(Some(pemfile::Item::RSAKey(pkey))) => PrivateKey(pkey),
|
||||
Ok(Some(pemfile::Item::PKCS8Key(pkey))) => PrivateKey(pkey),
|
||||
Ok(Some(pemfile::Item::ECKey(pkey))) => PrivateKey(pkey),
|
||||
Ok(Some(_)) => return Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}: unknown private key format", pkey_fn))),
|
||||
Ok(None) => return Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}: expected one private key", pkey_fn))),
|
||||
Err(_) => return Err(io::Error::new(io::ErrorKind::InvalidData, format!("{}: invalid data", pkey_fn))),
|
||||
};
|
||||
|
||||
// Certificate.
|
||||
let cert_fn = cfg.tls_cert.as_ref().ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::NotFound, "config: server: tls_cert not set")
|
||||
})?;
|
||||
let cert_file = File::open(cert_fn).map_err(|e| {
|
||||
io::Error::new(e.kind(), format!("{}: {}", cert_fn, e))
|
||||
})?;
|
||||
let mut cert_file = io::BufReader::new(cert_file);
|
||||
let certs = pemfile::certs(&mut cert_file).map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, format!("{}: invalid data", cert_fn))
|
||||
})?;
|
||||
let certs = certs
|
||||
.into_iter()
|
||||
.map(|cert| Certificate(cert.into()))
|
||||
.collect();
|
||||
|
||||
let config = Arc::new(
|
||||
ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, pkey)
|
||||
.map_err(|e| {
|
||||
io::Error::new(ErrorKind::InvalidData, format!("{}/{}: {}", pkey_fn, cert_fn, e))
|
||||
})?
|
||||
).into();
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
135
src/unixuser.rs
Normal file
135
src/unixuser.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use std;
|
||||
use std::ffi::{CStr, OsStr};
|
||||
use std::io;
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use tokio::task::block_in_place;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct User {
|
||||
pub name: String,
|
||||
pub passwd: String,
|
||||
pub gecos: String,
|
||||
pub uid: u32,
|
||||
pub gid: u32,
|
||||
pub groups: Vec<u32>,
|
||||
pub dir: PathBuf,
|
||||
pub shell: PathBuf,
|
||||
}
|
||||
|
||||
unsafe fn cptr_to_osstr<'a>(c: *const libc::c_char) -> &'a OsStr {
|
||||
let bytes = CStr::from_ptr(c).to_bytes();
|
||||
OsStr::from_bytes(&bytes)
|
||||
}
|
||||
|
||||
unsafe fn cptr_to_path<'a>(c: *const libc::c_char) -> &'a Path {
|
||||
Path::new(cptr_to_osstr(c))
|
||||
}
|
||||
|
||||
unsafe fn to_user(pwd: &libc::passwd) -> User {
|
||||
// turn into (unsafe!) rust slices
|
||||
let cs_name = CStr::from_ptr(pwd.pw_name);
|
||||
let cs_passwd = CStr::from_ptr(pwd.pw_passwd);
|
||||
let cs_gecos = CStr::from_ptr(pwd.pw_gecos);
|
||||
let cs_dir = cptr_to_path(pwd.pw_dir);
|
||||
let cs_shell = cptr_to_path(pwd.pw_shell);
|
||||
|
||||
// then turn the slices into safe owned values.
|
||||
User {
|
||||
name: cs_name.to_string_lossy().into_owned(),
|
||||
passwd: cs_passwd.to_string_lossy().into_owned(),
|
||||
gecos: cs_gecos.to_string_lossy().into_owned(),
|
||||
dir: cs_dir.to_path_buf(),
|
||||
shell: cs_shell.to_path_buf(),
|
||||
uid: pwd.pw_uid,
|
||||
gid: pwd.pw_gid,
|
||||
groups: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn by_name(name: &str, with_groups: bool) -> Result<User, io::Error> {
|
||||
let mut buf = [0u8; 1024];
|
||||
let mut pwd: libc::passwd = unsafe { std::mem::zeroed() };
|
||||
let mut result: *mut libc::passwd = std::ptr::null_mut();
|
||||
|
||||
let cname = match std::ffi::CString::new(name) {
|
||||
Ok(un) => un,
|
||||
Err(_) => return Err(io::Error::from_raw_os_error(libc::ENOENT)),
|
||||
};
|
||||
let ret = unsafe {
|
||||
libc::getpwnam_r(
|
||||
cname.as_ptr(),
|
||||
&mut pwd as *mut _,
|
||||
buf.as_mut_ptr() as *mut _,
|
||||
buf.len() as libc::size_t,
|
||||
&mut result as *mut _,
|
||||
)
|
||||
};
|
||||
|
||||
if ret != 0 {
|
||||
return Err(io::Error::from_raw_os_error(ret));
|
||||
}
|
||||
if result.is_null() {
|
||||
return Err(io::Error::from_raw_os_error(libc::ENOENT));
|
||||
}
|
||||
let mut user = unsafe { to_user(&pwd) };
|
||||
|
||||
if with_groups {
|
||||
let mut ngroups = (buf.len() / std::mem::size_of::<libc::gid_t>()) as libc::c_int;
|
||||
let ret = unsafe {
|
||||
libc::getgrouplist(
|
||||
cname.as_ptr(),
|
||||
user.gid as libc::gid_t,
|
||||
buf.as_mut_ptr() as *mut _,
|
||||
&mut ngroups as *mut _,
|
||||
)
|
||||
};
|
||||
if ret >= 0 && ngroups > 0 {
|
||||
let mut groups_vec = Vec::with_capacity(ngroups as usize);
|
||||
let groups = unsafe {
|
||||
std::slice::from_raw_parts(buf.as_ptr() as *const libc::gid_t, ngroups as usize)
|
||||
};
|
||||
//
|
||||
// Only supplementary or auxilary groups, filter out primary.
|
||||
//
|
||||
groups_vec.extend(groups.iter().map(|&g| g as u32).filter(|&g| g != user.gid));
|
||||
user.groups = groups_vec;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/*
|
||||
pub fn by_uid(uid: u32) -> Result<User, io::Error> {
|
||||
let mut buf = [0; 1024];
|
||||
let mut pwd: libc::passwd = unsafe { std::mem::zeroed() };
|
||||
let mut result: *mut libc::passwd = std::ptr::null_mut();
|
||||
|
||||
let ret = unsafe {
|
||||
getpwuid_r(
|
||||
uid,
|
||||
&mut pwd as *mut _,
|
||||
buf.as_mut_ptr(),
|
||||
buf.len() as libc::size_t,
|
||||
&mut result as *mut _,
|
||||
)
|
||||
};
|
||||
if ret == 0 {
|
||||
if result.is_null() {
|
||||
return Err(io::Error::from_raw_os_error(libc::ENOENT));
|
||||
}
|
||||
let p = unsafe { to_user(&pwd) };
|
||||
Ok(p)
|
||||
} else {
|
||||
Err(io::Error::from_raw_os_error(ret))
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
pub async fn by_name_async(name: &str, with_groups: bool) -> Result<User, io::Error> {
|
||||
block_in_place(move || User::by_name(name, with_groups))
|
||||
}
|
||||
}
|
||||
125
src/userfs.rs
Normal file
125
src/userfs.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use std::any::Any;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use webdav_handler::davpath::DavPath;
|
||||
use webdav_handler::fs::*;
|
||||
use webdav_handler::localfs::LocalFs;
|
||||
|
||||
use crate::suid::UgidSwitch;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct UserFs {
|
||||
pub fs: LocalFs,
|
||||
basedir: PathBuf,
|
||||
uid: u32,
|
||||
}
|
||||
|
||||
impl UserFs {
|
||||
pub fn new(
|
||||
dir: impl AsRef<Path>,
|
||||
target_creds: Option<(u32, u32, &[u32])>,
|
||||
public: bool,
|
||||
case_insensitive: bool,
|
||||
macos: bool,
|
||||
) -> Box<UserFs>
|
||||
{
|
||||
// uid is used for quota() calls.
|
||||
let uid = target_creds.as_ref().map(|ugid| ugid.0).unwrap_or(0);
|
||||
|
||||
// set up the LocalFs hooks for uid switching.
|
||||
let switch = UgidSwitch::new(target_creds.clone());
|
||||
let blocking_guard = Box::new(move || Box::new(switch.guard()) as Box<dyn Any>);
|
||||
|
||||
Box::new(UserFs {
|
||||
basedir: dir.as_ref().to_path_buf(),
|
||||
fs: *LocalFs::new_with_fs_access_guard(
|
||||
dir,
|
||||
public,
|
||||
case_insensitive,
|
||||
macos,
|
||||
Some(blocking_guard),
|
||||
),
|
||||
uid: uid,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl DavFileSystem for UserFs {
|
||||
fn metadata<'a>(&'a self, path: &'a DavPath) -> FsFuture<Box<dyn DavMetaData>> {
|
||||
self.fs.metadata(path)
|
||||
}
|
||||
|
||||
fn symlink_metadata<'a>(&'a self, path: &'a DavPath) -> FsFuture<Box<dyn DavMetaData>> {
|
||||
self.fs.symlink_metadata(path)
|
||||
}
|
||||
|
||||
fn read_dir<'a>(
|
||||
&'a self,
|
||||
path: &'a DavPath,
|
||||
meta: ReadDirMeta,
|
||||
) -> FsFuture<FsStream<Box<dyn DavDirEntry>>>
|
||||
{
|
||||
self.fs.read_dir(path, meta)
|
||||
}
|
||||
|
||||
fn open<'a>(&'a self, path: &'a DavPath, options: OpenOptions) -> FsFuture<Box<dyn DavFile>> {
|
||||
self.fs.open(path, options)
|
||||
}
|
||||
|
||||
fn create_dir<'a>(&'a self, path: &'a DavPath) -> FsFuture<()> {
|
||||
self.fs.create_dir(path)
|
||||
}
|
||||
|
||||
fn remove_dir<'a>(&'a self, path: &'a DavPath) -> FsFuture<()> {
|
||||
self.fs.remove_dir(path)
|
||||
}
|
||||
|
||||
fn remove_file<'a>(&'a self, path: &'a DavPath) -> FsFuture<()> {
|
||||
self.fs.remove_file(path)
|
||||
}
|
||||
|
||||
fn rename<'a>(&'a self, from: &'a DavPath, to: &'a DavPath) -> FsFuture<()> {
|
||||
self.fs.rename(from, to)
|
||||
}
|
||||
|
||||
fn copy<'a>(&'a self, from: &'a DavPath, to: &'a DavPath) -> FsFuture<()> {
|
||||
self.fs.copy(from, to)
|
||||
}
|
||||
|
||||
#[cfg(feature = "quota")]
|
||||
fn get_quota<'a>(&'a self) -> FsFuture<(u64, Option<u64>)> {
|
||||
use crate::cache;
|
||||
use fs_quota::*;
|
||||
use futures::future::FutureExt;
|
||||
use std::time::Duration;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref QCACHE: cache::Cache<PathBuf, FsQuota> = cache::Cache::new().maxage(Duration::new(30, 0));
|
||||
}
|
||||
|
||||
async move {
|
||||
let mut key = self.basedir.clone();
|
||||
key.push(&self.uid.to_string());
|
||||
let r = match QCACHE.get(&key) {
|
||||
Some(r) => {
|
||||
debug!("get_quota for {:?}: from cache", key);
|
||||
r
|
||||
},
|
||||
None => {
|
||||
let path = self.basedir.clone();
|
||||
let uid = self.uid;
|
||||
let r = self
|
||||
.fs
|
||||
.blocking(move || {
|
||||
FsQuota::check(&path, Some(uid)).map_err(|_| FsError::GeneralFailure)
|
||||
})
|
||||
.await?;
|
||||
debug!("get_quota for {:?}: insert to cache", key);
|
||||
QCACHE.insert(key, r)
|
||||
},
|
||||
};
|
||||
Ok((r.bytes_used, r.bytes_limit))
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user