fix: encoding of cache

This commit is contained in:
Hazel 2024-04-26 14:04:44 +02:00
parent e77afa584b
commit 25eceb727b
2 changed files with 21 additions and 6 deletions

View File

@ -1,6 +1,6 @@
import json
from pathlib import Path
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional
from functools import lru_cache
@ -17,6 +17,8 @@ class CacheAttribute:
created: datetime
expires: datetime
additional_info: dict = field(default_factory=dict)
@property
def id(self):
@ -32,6 +34,12 @@ class CacheAttribute:
return self.__dict__ == other.__dict__
@dataclass
class CacheResult:
content: bytes
attribute: CacheAttribute
class Cache:
def __init__(self, module: str, logger: logging.Logger):
self.module = module
@ -100,7 +108,7 @@ class Cache:
return True
def set(self, content: bytes, name: str, expires_in: float = 10, module: str = ""):
def set(self, content: bytes, name: str, expires_in: float = 10, module: str = "", additional_info: dict = None):
"""
:param content:
:param module:
@ -111,6 +119,7 @@ class Cache:
if name == "":
return
additional_info = additional_info or {}
module = self.module if module == "" else module
module_path = self._init_module(module)
@ -128,7 +137,7 @@ class Cache:
self.logger.debug(f"writing cache to {cache_path}")
content_file.write(content)
def get(self, name: str) -> Optional[bytes]:
def get(self, name: str) -> Optional[CacheResult]:
path = fit_to_file_system(Path(self._dir, self.module, name), hidden_ok=True)
if not path.is_file():
@ -140,7 +149,7 @@ class Cache:
return
with path.open("rb") as f:
return f.read()
return CacheResult(content=f.read(), attribute=existing_attribute)
def clean(self):
keep = set()

View File

@ -133,7 +133,9 @@ class Connection:
if self.cache.get(name) is not None and no_update_if_valid_exists:
return
self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration), **n_kwargs)
self.cache.set(r.content, name, expires_in=kwargs.get("expires_in", self.cache_expiring_duration), additional_info={
"encoding", r.encoding,
}, **n_kwargs)
def request(
self,
@ -189,10 +191,14 @@ class Connection:
request_trace(f"{trace_string}\t[cached]")
with responses.RequestsMock() as resp:
body = cached.content
if "encoding" in cached.additional_info:
body = body.decode(cached.additional_info["encoding"])
resp.add(
method=method,
url=request_url,
body=cached,
body=body,
)
return requests.request(method=method, url=url, timeout=timeout, headers=headers, **kwargs)