Commit 10df9b45 authored by Martin van Es's avatar Martin van Es
Browse files

Move all entities logic to Resource class

parent 32ee2257
......@@ -19,27 +19,15 @@ server = Server()
strict_slashes=False,
methods=['GET'])
def serve_all(realm):
dirty = False
for key, resource in server.items():
if resource.dirty:
dirty = True
resource.dirty = False
print(f"all in {realm}")
response = Response()
response.headers['Content-Type'] = "application/samlmetadata+xml"
response.headers['Content-Disposition'] = "filename = \"metadata.xml\""
if server.all_entities is not None and not dirty:
print("cache all")
data = server.all_entities
response.data = data.md
else:
print("sign all")
data = server[realm].all_entities()
response.data = data.md
server.all_entities = data
max_age = int((data.valid_until - datetime.now(tz.tzutc())).total_seconds())
data = server[realm].all_entities()
response.data = data.md
max_age = int((data.valid_until -
datetime.now(tz.tzutc())).total_seconds())
response.headers['Cache-Control'] = f"max-age={max_age}"
response.headers['Last-Modified'] = formatdate(timeval=mktime(data.last_modified.timetuple()),
......
......@@ -62,6 +62,7 @@ class Server(dict):
class Resource:
watch_list = {}
dirty = False
all_cache = None
def __init__(self, location, signer):
self.idps = {}
......@@ -120,7 +121,7 @@ class Resource:
cache_duration = parse_duration(cacheDuration)
last_modified = datetime.now(tz.tzutc())
if valid_until > datetime.now(tz.tzutc()):
self.dirty = True
self.all_cache = None
for entity_descriptor in root.findall('md:EntityDescriptor', ns):
entityID = entity_descriptor.attrib.get('entityID', 'none')
sha1 = hasher(entityID)
......@@ -191,28 +192,35 @@ class Resource:
return data
def all_entities(self):
data = MData()
ns = NSMAP
root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor",
nsmap=ns)
# We are going to minimize expires, so set to some inf value
valid_until = (datetime.now(tz.tzutc()) +
timedelta(days=365))
cache_duration = parse_duration("P1D")
for sha1, entity in self.idps.items():
valid_until = min(valid_until, entity.valid_until)
cache_duration = min(cache_duration, entity.cache_duration)
ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration')
root.append(entity.md)
vu_zulu = str(valid_until).replace('+00:00', 'Z')
root.set('validUntil', vu_zulu)
root.set('cacheDuration', duration_isoformat(cache_duration))
last_modified = datetime.now(tz.tzutc())
signed_root = self.signer(root)
data.md = ET.tostring(signed_root, pretty_print=True)
data.valid_until = valid_until
data.last_modified = last_modified
if self.all_cache is not None:
print("cache all")
data = self.all_cache
else:
print("sign all")
data = MData()
ns = NSMAP
root = ET.Element(f"{{{MD_NAMESPACE}}}EntitiesDescriptor",
nsmap=ns)
# We are going to minimize expires, so set to some inf value
valid_until = (datetime.now(tz.tzutc()) +
timedelta(days=365))
cache_duration = parse_duration("P1D")
for sha1, entity in self.idps.items():
valid_until = min(valid_until, entity.valid_until)
cache_duration = min(cache_duration, entity.cache_duration)
ET.strip_attributes(entity.md, 'validUntil', 'cacheDuration')
root.append(entity.md)
vu_zulu = str(valid_until).replace('+00:00', 'Z')
root.set('validUntil', vu_zulu)
root.set('cacheDuration', duration_isoformat(cache_duration))
last_modified = datetime.now(tz.tzutc())
signed_root = self.signer(root)
data.md = ET.tostring(signed_root, pretty_print=True)
data.valid_until = valid_until
data.last_modified = last_modified
self.all_cache = data
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment