import re
import time
from boto.dynamodb2.exceptions import ItemNotFound
from cryptography.fernet import InvalidToken
from cryptography.exceptions import InvalidSignature
from marshmallow import (
Schema,
fields,
pre_load,
post_load,
validates,
validates_schema,
)
from marshmallow_polyfield import PolyField
from marshmallow.validate import OneOf
from twisted.logger import Logger # noqa
from twisted.internet.defer import Deferred # noqa
from twisted.internet.defer import maybeDeferred
from twisted.internet.threads import deferToThread
from typing import ( # noqa
Any,
Dict,
Optional
)
from autopush.crypto_key import CryptoKey
from autopush.db import dump_uaid, hasher
from autopush.exceptions import (
InvalidRequest,
InvalidTokenException,
VapidAuthException,
)
from autopush.settings import AutopushSettings # noqa
from autopush.utils import (
base64url_encode,
extract_jwt,
ms_time,
WebPushNotification,
normalize_id,
parse_auth_header,
)
from autopush.web.base import (
threaded_validate,
BaseWebHandler,
PREF_SCHEME,
)
# Upper bound for a notification TTL: sixty days, expressed in seconds.
MAX_TTL = 60 * 24 * 60 * 60

# Accepts text drawn from the URL- and filename-safe Base64 alphabet
# (letters, digits, "-" and "_"), optionally followed by "=" padding.
VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_]+=*$')
class WebPushSubscriptionSchema(Schema):
    """Turn endpoint token information into a checked subscription.

    The token is decoded through ``settings.parse_endpoint`` before the
    fields are loaded; afterwards the owning user record is fetched from
    the router table and vetted.  The looked-up record is handed back to
    the caller under the ``user_data`` key.
    """
    uaid = fields.UUID(required=True)
    chid = fields.UUID(required=True)
    public_key = fields.Raw(missing=None)

    @pre_load
    def extract_subscription(self, d):
        """Decode the endpoint token into the data this schema loads.

        :raises InvalidRequest: 401/errno 109 when the endpoint parser
            reports a VAPID auth problem, 404/errno 102 when the token
            itself is rejected.
        """
        try:
            return self.context["settings"].parse_endpoint(
                token=d["token"],
                version=d["api_ver"],
                ckey_header=d["ckey_header"],
                auth_header=d["auth_header"],
            )
        except VapidAuthException:
            raise InvalidRequest("missing authorization header",
                                 status_code=401, errno=109)
        except (InvalidTokenException, InvalidToken):
            raise InvalidRequest("invalid token", status_code=404, errno=102)

    @validates_schema(skip_on_field_errors=True)
    def validate_uaid_month_and_chid(self, d):
        """Look up the user record for ``d["uaid"]`` and vet it.

        :raises InvalidRequest: when the UAID is unknown (410/103), the
            router type is not one served here (errno 108), or the record
            carries a critical failure (410/105).
        """
        settings = self.context["settings"]  # type: AutopushSettings
        try:
            user_record = settings.router.get_uaid(d["uaid"].hex)
        except ItemNotFound:
            raise InvalidRequest("UAID not found", status_code=410, errno=103)

        router_type = user_record.get("router_type")
        if router_type not in ("webpush", "gcm", "apns", "fcm"):
            raise InvalidRequest("Wrong URL for user", errno=108)

        if router_type in ("gcm", "fcm"):
            creds = user_record.get('router_data', {}).get("creds", {})
            if 'senderID' not in creds:
                # Flag the record as bad (keeping any failure already
                # recorded) and persist that before rejecting below.
                user_record['critical_failure'] = user_record.get(
                    'critical_failure', "Missing SenderID")
                settings.router.register_user(user_record)

        failure = user_record.get("critical_failure")
        if failure:
            raise InvalidRequest("Critical Failure: %s" % failure,
                                 status_code=410,
                                 errno=105)

        if user_record["router_type"] == "webpush":
            self._validate_webpush(d, user_record)

        # Hand the looked-up user data back to the caller.
        d["user_data"] = user_record

    def _validate_webpush(self, d, result):
        """Check the message month and channel of a ``webpush`` user.

        A record without a usable ``current_month`` causes the user to be
        dropped from the router table.

        :raises InvalidRequest: 410/errno 106 whenever the subscription
            cannot be confirmed.
        """
        settings = self.context["settings"]  # type: AutopushSettings
        log = self.context["log"]  # type: Logger
        channel_id = normalize_id(d["chid"])
        uaid = result["uaid"]

        # Work out whether (and with which log code) the user is dropped.
        month_table = None
        if 'current_month' not in result:
            drop_code = 102
        else:
            month_table = result["current_month"]
            drop_code = (103 if month_table not in settings.message_tables
                         else None)

        if drop_code is not None:
            log.info(format="Dropping User", code=drop_code,
                     uaid_hash=hasher(uaid),
                     uaid_record=dump_uaid(result))
            settings.router.drop_user(uaid)
            raise InvalidRequest("No such subscription", status_code=410,
                                 errno=106)

        table = settings.message_tables[month_table]
        found, channels = table.all_channels(uaid=uaid)
        subscribed = found and (
            channel_id.lower() in [normalize_id(chan) for chan in channels])
        if not subscribed:
            log.info("Unknown subscription: {channel_id}",
                     channel_id=channel_id)
            raise InvalidRequest("No such subscription", status_code=410,
                                 errno=106)
class WebPushBasicHeaderSchema(Schema):
    """Validate the request headers that do not depend on encryption.

    ``ttl`` and ``topic`` are optional and load as ``None`` when absent.
    """
    authorization = fields.String()
    ttl = fields.Integer(required=False, missing=None)
    topic = fields.String(required=False, missing=None)
    api_ver = fields.String()

    @validates('topic')
    def validate_topic(self, value):
        """Require a topic of at most 32 URL-safe Base64 characters.

        :raises InvalidRequest: errno 113 when the topic is too long or
            contains characters outside the URL-safe Base64 alphabet.
        """
        if value is None:
            return True

        if len(value) > 32:
            raise InvalidRequest("Topic must be no greater than 32 "
                                 "characters", errno=113)

        if not VALID_BASE64_URL.match(value):
            raise InvalidRequest("Topic must be URL and Filename safe Base"
                                 "64 alphabet", errno=113)

    @post_load
    def cap_ttl(self, d):
        """Clamp a supplied TTL to ``MAX_TTL``; leave a missing TTL alone."""
        # ``ttl`` loads as None when the header is absent (missing=None).
        # Comparing None with an int raises TypeError on Python 3, so only
        # clamp real values.
        if d.get("ttl") is not None:
            d["ttl"] = min(d["ttl"], MAX_TTL)
        return d
class WebPushCrypto01HeaderSchema(Schema):
    """Validates WebPush Message Encryption

    Uses draft-ietf-webpush-encryption-01 rules for validation.

    """
    content_encoding = fields.String(
        required=True,
        load_from="content-encoding",
        validate=OneOf(["aesgcm128"])
    )
    encryption = fields.String(required=True)
    encryption_key = fields.String(
        required=True,
        load_from="encryption-key"
    )
    crypto_key = fields.String(load_from="crypto-key")

    @validates("encryption")
    def validate_encryption(self, value):
        """Must contain a salt value

        :raises InvalidRequest: 400/errno 110 when the ``salt`` label is
            missing or is not URL-safe Base64.
        """
        salt = CryptoKey.parse_and_get_label(value, "salt")
        if not salt or not VALID_BASE64_URL.match(salt):
            raise InvalidRequest("Invalid salt value in Encryption header",
                                 status_code=400,
                                 errno=110)

    @validates("crypto_key")
    def validate_crypto_key(self, value):
        """Must not contain a dh value

        :raises InvalidRequest: 400/errno 110 when a ``dh`` label is
            present in the Crypto-Key header.
        """
        dh = CryptoKey.parse_and_get_label(value, "dh")
        if dh:
            raise InvalidRequest(
                "dh value in Crypto-Key header not valid for 01 or earlier "
                "webpush-encryption",
                status_code=400,
                errno=110,
            )

    @validates("encryption_key")
    def validate_encryption_key(self, value):
        """Must contain a dh value

        :raises InvalidRequest: 400/errno 110 when the ``dh`` label is
            missing or is not URL-safe Base64.
        """
        dh = CryptoKey.parse_and_get_label(value, "dh")
        # Check the parsed value itself; matching the literal string "dh"
        # would always succeed and let malformed keys through.
        if not dh or not VALID_BASE64_URL.match(dh):
            raise InvalidRequest("Invalid dh value in Encryption-Key header",
                                 status_code=400,
                                 errno=110)
class WebPushCrypto04HeaderSchema(Schema):
    """Validates WebPush Message Encryption

    Uses draft-ietf-webpush-encryption-04 rules for validation.

    """
    content_encoding = fields.String(
        required=True,
        load_from="content-encoding",
        validate=OneOf(["aesgcm"])
    )
    encryption = fields.String(required=True)
    # NOTE(review): this field is not marked required, so the dh check
    # below only runs when a Crypto-Key header was sent -- confirm that is
    # intended.
    crypto_key = fields.String(
        load_from="crypto-key",
    )

    @validates("encryption")
    def validate_encryption(self, value):
        """Must contain a salt value

        :raises InvalidRequest: 400/errno 110 when the ``salt`` label is
            missing or is not URL-safe Base64.
        """
        salt = CryptoKey.parse_and_get_label(value, "salt")
        if not salt or not VALID_BASE64_URL.match(salt):
            raise InvalidRequest("Invalid salt value in Encryption header",
                                 status_code=400,
                                 errno=110)

    @validates("crypto_key")
    def validate_crypto_key(self, value):
        """Must contain a dh value

        :raises InvalidRequest: 400/errno 110 when the ``dh`` label is
            missing or is not URL-safe Base64.
        """
        dh = CryptoKey.parse_and_get_label(value, "dh")
        # Check the parsed value itself; matching the literal string "dh"
        # would always succeed and let malformed keys through.
        if not dh or not VALID_BASE64_URL.match(dh):
            # The header validated here is Crypto-Key, so name it as such.
            raise InvalidRequest("Invalid dh value in Crypto-Key header",
                                 status_code=400,
                                 errno=110)

    @validates_schema(pass_original=True)
    def reject_encryption_key(self, data, original_data):
        """Refuse requests that still send an Encryption-Key header.

        :raises InvalidRequest: 400/errno 110 when ``encryption-key`` is
            present in the original (unprocessed) headers.
        """
        if "encryption-key" in original_data:
            raise InvalidRequest(
                "Encryption-Key header not valid for 02 or later "
                "webpush-encryption",
                status_code=400,
                errno=110,
            )
class WebPushInvalidContentEncodingSchema(Schema):
    """Schema whose validation always fails.

    Selected when the Content-Encoding of a request with a body is not
    recognised, so that loading it reports the problem.
    """

    @validates_schema
    def invalid_content_encoding(self, d):
        """Reject unconditionally with a 400 / errno 110 error."""
        error = InvalidRequest(
            "Unknown Content-Encoding",
            status_code=400,
            errno=110
        )
        raise error
def conditional_crypto_deserialize(object_dict, parent_object_dict):
    """Pick the schema used to validate the crypto headers.

    :param object_dict: the request headers being deserialized.
    :param parent_object_dict: the enclosing request data; only its
        ``body`` entry is consulted.
    :returns: a plain ``Schema`` when there is no data payload, otherwise
        the schema matching the ``content-encoding`` header (or one that
        always fails for an unknown encoding).
    """
    # Without a payload there are no crypto headers to validate.
    if not parent_object_dict.get("body"):
        return Schema()

    content_encoding = object_dict.get("content-encoding")
    if content_encoding == "aesgcm128":
        return WebPushCrypto01HeaderSchema()
    if content_encoding == "aesgcm":
        return WebPushCrypto04HeaderSchema()
    return WebPushInvalidContentEncodingSchema()
class WebPushRequestSchema(Schema):
    """Validate a complete WebPush request and build its notification.

    Loads the subscription (from the endpoint token), the basic headers
    and, when a body is present, the crypto headers; then verifies any
    VAPID authorization and attaches a ``WebPushNotification``.
    """
    subscription = fields.Nested(WebPushSubscriptionSchema,
                                 load_from="token_info")
    headers = fields.Nested(WebPushBasicHeaderSchema)
    crypto_headers = PolyField(
        load_from="headers",
        deserialization_schema_selector=conditional_crypto_deserialize,
    )
    body = fields.Raw()
    token_info = fields.Raw()
    vapid_version = fields.String(required=False, missing=None)

    @validates('body')
    def validate_data(self, value):
        """Reject a payload longer than ``settings.max_data``.

        :raises InvalidRequest: errno 104 when the body is too large.
        """
        max_data = self.context["settings"].max_data
        if value and len(value) > max_data:
            raise InvalidRequest(
                "Data payload must be smaller than {}".format(max_data),
                errno=104,
            )

    @pre_load
    def token_prep(self, d):
        """Collect the values the subscription schema needs.

        Gathers the API version and token from the URL path arguments and
        the crypto-key / authorization headers into ``d["token_info"]``.
        """
        d["token_info"] = dict(
            api_ver=d["path_kwargs"].get("api_ver"),
            token=d["path_kwargs"].get("token"),
            ckey_header=d["headers"].get("crypto-key", ""),
            auth_header=d["headers"].get("authorization", ""),
        )
        return d

    @staticmethod
    def _auth_error(message):
        """Build the 401 / errno 109 error used for every auth failure."""
        return InvalidRequest(message,
                              status_code=401, errno=109,
                              headers={"www-authenticate": PREF_SCHEME})

    def validate_auth(self, d):
        """Verify the VAPID authorization header, if one is needed.

        Nothing is checked when the endpoint is not ``v2`` and no
        authorization header was sent.  On success ``d["vapid_version"]``
        and ``d["jwt"]`` are filled in.

        :raises InvalidRequest: 401/errno 109 when the header cannot be
            parsed or verified, or when the token's ``exp`` claim is
            missing, malformed, already past, or more than 24 hours away.
        """
        auth = d["headers"].get("authorization")
        needs_auth = d["token_info"]["api_ver"] == "v2"
        if not needs_auth and not auth:
            return

        try:
            vapid_auth = parse_auth_header(auth)
            token = vapid_auth['t']
            d["vapid_version"] = "draft{:0>2}".format(
                vapid_auth['version'])
            if vapid_auth['version'] == 2:
                public_key = vapid_auth['k']
            else:
                public_key = d["subscription"].get("public_key")
            jwt = extract_jwt(token, public_key)
        except (KeyError, ValueError, InvalidSignature, TypeError,
                VapidAuthException):
            raise self._auth_error("Invalid Authorization Header")

        if "exp" not in jwt:
            raise self._auth_error("Invalid bearer token: No expiration")

        try:
            jwt_expires = int(jwt['exp'])
        except (TypeError, ValueError, OverflowError):
            # A non-numeric claim (None, list, ...) raises TypeError and an
            # infinite float raises OverflowError; both are client errors,
            # not server faults.
            raise self._auth_error(
                "Invalid bearer token: Invalid expiration")

        now = time.time()
        if now > jwt_expires:
            raise self._auth_error("Invalid bearer token: Auth expired")

        if (jwt_expires - now) > (60 * 60 * 24):
            raise self._auth_error("Invalid bearer token: Auth > 24 hours "
                                   "in the future")

        jwt_crypto_key = base64url_encode(public_key)
        d["jwt"] = dict(jwt_crypto_key=jwt_crypto_key, jwt_data=jwt)

    @post_load
    def fixup_output(self, d):
        """Finish the request: check auth, merge headers, build the result.

        :returns: ``d`` with the crypto headers folded into ``headers``,
            the body Base64-URL encoded and a ``notification`` entry added.
        """
        # Verify authorization
        # Note: This has to be done here, since schema validation takes place
        # before nested schemas, and in this case we need all the nested
        # schema logic to run first.
        self.validate_auth(d)

        # Merge crypto headers back in, restoring the dashed header names
        if d["crypto_headers"]:
            d["headers"].update(
                {k.replace("_", "-"): v for k, v in
                 d["crypto_headers"].items()}
            )

        # Base64-encode data for Web Push
        d["body"] = base64url_encode(d["body"])

        # Set the notification based on the validated request schema data
        d["notification"] = WebPushNotification.from_webpush_request_schema(
            data=d, fernet=self.context["settings"].fernet,
            legacy=self.context["settings"]._notification_legacy,
        )
        return d
class WebPushHandler(BaseWebHandler):
    """HTTP handler that accepts WebPush notifications via POST."""
    cors_methods = "POST"
    cors_request_headers = ("content-encoding", "encryption",
                            "crypto-key", "ttl",
                            "encryption-key", "content-type",
                            "authorization")
    cors_response_headers = ("location", "www-authenticate")

    @threaded_validate(WebPushRequestSchema)
    def post(self,
             subscription,  # type: Dict[str, Any]
             notification,  # type: WebPushNotification
             jwt=None,  # type: Optional[Dict[str, str]]
             **kwargs  # type: Any
             ):
        # type: (...) -> Deferred
        """Record client info and hand the notification to its router.

        :param subscription: validated subscription data; its
            ``user_data`` entry is the looked-up user record.
        :param notification: the notification to route.
        :param jwt: VAPID details (``jwt_crypto_key`` and ``jwt_data``)
            when the request carried them.
        :returns: a Deferred for the routing operation.
        """
        # Store Vapid info if present
        if jwt:
            self.metrics.increment("updates.vapid.{}".format(
                kwargs.get('vapid_version'))
            )
            self._client_info["jwt_crypto_key"] = jwt["jwt_crypto_key"]
            for claim, value in jwt["jwt_data"].items():
                self._client_info["jwt_" + claim] = value

        user_data = subscription["user_data"]
        encoding = ''
        if notification.data and notification.headers:
            encoding = notification.headers.get('encoding', '')
        self._client_info.update(
            message_id=notification.message_id,
            uaid_hash=hasher(user_data.get("uaid")),
            channel_id=user_data.get("chid"),
            router_key=user_data["router_type"],
            message_size=len(notification.data or ""),
            message_ttl=notification.ttl,
            version=notification.version,
            encoding=encoding,
        )

        router = self.ap_settings.routers[user_data["router_type"]]
        # Remembered so the completion callback can time the routing step.
        self._router_time = time.time()
        d = maybeDeferred(router.route_notification, notification, user_data)
        d.addCallback(self._router_completed, user_data, "")
        d.addErrback(self._router_fail_err)
        d.addErrback(self._response_err)
        return d

    def _router_completed(self, response, uaid_data, warning=""):
        """Called after router has completed successfully

        :param response: the router's response object.
        :param uaid_data: the user record the notification was routed for.
        :param warning: text appended to the response body.
        :returns: a Deferred when the user record has to be dropped or
            re-registered first, otherwise ``None``.
        """
        # Log the time taken for routing
        self._timings["route_time"] = time.time() - self._router_time
        # Were we told to update the router data?
        time_diff = time.time() - self._start_time
        if response.router_data is not None:
            if not response.router_data:
                # An empty router_data object indicates that the record should
                # be deleted. There is no longer valid route information for
                # this record.
                self.log.info(format="Dropping User", code=100,
                              uaid_hash=hasher(uaid_data["uaid"]),
                              uaid_record=dump_uaid(uaid_data),
                              client_info=self._client_info)
                d = deferToThread(self.ap_settings.router.drop_user,
                                  uaid_data["uaid"])
                d.addCallback(lambda x: self._router_response(response))
                return d
            # The router data needs to be updated to include any changes
            # requested by the bridge system
            uaid_data["router_data"] = response.router_data
            # set the AWS mandatory data
            uaid_data["connected_at"] = ms_time()
            d = deferToThread(self.ap_settings.router.register_user,
                              uaid_data)
            # Cleared so the re-entrant call below takes the "no changes"
            # branch once the record has been stored.
            response.router_data = None
            d.addCallback(lambda x: self._router_completed(
                response,
                uaid_data,
                warning))
            return d
        else:
            # No changes are requested by the bridge system, proceed as normal
            if response.status_code == 200 or response.logged_status == 200:
                self.log.info(format="Successful delivery",
                              client_info=self._client_info)
            elif response.status_code == 202 or response.logged_status == 202:
                self.log.info(
                    format="Router miss, message stored.",
                    client_info=self._client_info)
            self.metrics.timing("updates.handled", duration=time_diff)
            response.response_body = (
                response.response_body + " " + warning).strip()
            self._router_response(response)