Commit 7379cf3e authored by Martin van Es's avatar Martin van Es
Browse files

Respect validUntil and cacheDuration

parent 1a78c08c
......@@ -5,6 +5,7 @@ from flask import Flask, Response
from urllib.parse import unquote
from dateutil import parser, tz
from datetime import datetime
from isoduration import parse_duration
from utils import read_config, hasher, Entity
......@@ -30,27 +31,35 @@ def serve(domain, eid):
cached[domain] = cached.get(domain, {})
if entityID in cached[domain]:
if cached[domain][entityID].valid_until > datetime.now(tz.tzutc()):
if cached[domain][entityID].expires > datetime.now(tz.tzutc()):
print(f"serve {entityID}")
return cached[domain][entityID].md
else:
print(f"request {entityID}")
data = requests.get(f"{config[domain]['signer']}/{domain}/entities/{{sha1}}{entityID}").text
try:
parsed = ET.fromstring(data)
validUntil = parsed.get('validUntil')
# cacheDuration = parsed.get('cacheDuration')
cached_entity = Entity()
cached_entity.md = data
cached_entity.valid_until = parser.isoparse(validUntil)
print(f"request {entityID}")
data = requests.get(f"{config[domain]['signer']}/{domain}"
f"/entities/{{sha1}}{entityID}").text
try:
root = ET.fromstring(data)
validUntil = root.get('validUntil')
cacheDuration = root.get('cacheDuration')
cached_entity = Entity()
cached_entity.md = data
cached_entity.valid_until = parser.isoparse(validUntil)
cached_entity.cache_duration = parse_duration(cacheDuration)
cached_entity.expires = min(datetime.now(tz.tzutc())
+ cached_entity.cache_duration,
cached_entity.valid_until)
if cached_entity.valid_until > datetime.now(tz.tzutc()):
cached[domain][entityID] = cached_entity
except ET.XMLSyntaxError:
data = "No valid metadata\n"
response.headers['Content-type'] = "text/html"
response.status = 404
else:
raise KeyError
except Exception:
data = "No valid metadata\n"
response.headers['Content-type'] = "text/html"
response.status = 404
response.data = data
return response
response.data = data
return response
app.run(host='0.0.0.0', port=80)
app.run(host='127.0.0.1', port=5002)
......@@ -6,7 +6,9 @@ config = read_config('mdserver.yaml')
app = Flask(__name__)
server = Server()
@app.route('/<domain>/entities/<path:entity_id>', methods=['GET'])
@app.route('/<domain>/entities/<path:entity_id>',
methods=['GET'])
def serve(domain, entity_id):
response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml"
......@@ -31,4 +33,4 @@ for domain, values in config.items():
if __name__ == "__main__":
app.run(host='0.0.0.0', port=5001, debug=False)
app.run(host='127.0.0.1', port=5001, debug=False)
......@@ -54,8 +54,7 @@ class Resource:
found = 0
removed = 0
old_idps = self.mdfiles[mdfile].copy()
tree = ET.ElementTree(file=mdfile)
root = tree.getroot()
root = ET.ElementTree(file=mdfile).getroot()
ns = root.nsmap.copy()
ns['xml'] = 'http://www.w3.org/XML/1998/namespace'
validUntil = root.get('validUntil')
......@@ -73,6 +72,8 @@ class Resource:
entity.md = entity_descriptor
entity.valid_until = valid_until
entity.cache_duration = cache_duration
entity.expires = min(datetime.now(tz.tzutc()) + cache_duration,
valid_until)
self.idps[sha1] = entity
self.__dict__.pop(sha1, None)
if sha1 in self.mdfiles[mdfile]:
......@@ -96,27 +97,31 @@ class Resource:
else:
sha1 = hasher(entityID)
data = None
if sha1 in self.__dict__:
signed_entity = self.__dict__[sha1]
if signed_entity.valid_until > datetime.now(tz.tzutc()):
if signed_entity.expires > datetime.now(tz.tzutc()):
data = self.__dict__[sha1].md
elif sha1 in self.idps:
if data is None and sha1 in self.idps:
try:
print(f"sign {sha1}")
valid_until = self.idps[sha1].valid_until
if valid_until > datetime.now(tz.tzutc()):
signed_element = self.signer(self.idps[sha1].md)
signed_xml = ET.tostring(signed_element, pretty_print=True).decode()
signed_xml = ET.tostring(signed_element,
pretty_print=True).decode()
signed_entity = Entity()
signed_entity.md = signed_xml
signed_entity.valid_until = self.idps[sha1].valid_until
signed_entity.expires = (datetime.now(tz.tzutc())
+ self.idps[sha1].cache_duration)
self.__dict__[sha1] = signed_entity
data = signed_xml
else:
raise KeyError
except Exception as e:
print(sha1)
print(f" {e}")
else:
raise KeyError
print(f"serve {sha1}")
return data
......@@ -135,7 +140,8 @@ class Server:
def __init__(self):
self.watch_manager = pyinotify.WatchManager()
self.event_notifier = pyinotify.ThreadedNotifier(self.watch_manager, EventProcessor(self))
self.event_notifier = pyinotify.ThreadedNotifier(self.watch_manager,
EventProcessor(self))
self.event_notifier.start()
def add_watch(self, domain, location):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment