# This file is part of Lisien, a framework for life simulation games.
# Copyright (c) Zachary Spector, public@zacharyspector.com
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Common classes for collections in lisien
Notably includes wrappers for mutable objects, allowing them to be stored in
the database. These simply store the new value.
Most of these are subclasses of :class:`blinker.Signal`, so you can listen
for changes using the ``connect(..)`` method.
"""
from __future__ import annotations
import ast
import base64
import importlib.util
import json
import os
import sys
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import MutableMapping
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property
from hashlib import blake2b
from inspect import getsource
from types import FunctionType, MethodType
from typing import TYPE_CHECKING, Callable, Iterator, TypeVar
import networkx as nx
from blinker import Signal
from .types import (
AbstractEngine,
AbstractFunctionStore,
ActionFunc,
ActionFuncName,
CharName,
KeyHint,
PrereqFunc,
PrereqFuncName,
Stat,
TriggerFunc,
TriggerFuncName,
UniversalKey,
Value,
ValueHint,
sort_set,
)
from .util import dedent_source, getatt
from .wrap import wrapval
if TYPE_CHECKING:
from .character import Character
# Field delimiters for hashing, stored UTF-8 encoded.
# U+241D is SYMBOL FOR GROUP SEPARATOR and U+241E is SYMBOL FOR RECORD
# SEPARATOR (the printable "control pictures", not the C0 control codes).
# StringStore.blake2b feeds GROUP_SEP after each key and REC_SEP after each
# value.
GROUP_SEP = "\u241d".encode("utf-8")
REC_SEP = "\u241e".encode("utf-8")
class AbstractLanguageDescriptor(Signal, ABC):
    """Descriptor behind a string store's ``language`` attribute.

    Reading the attribute delegates to :meth:`_get_language`. Assigning to it
    delegates to :meth:`_set_language` and then sends this signal, with the
    store as the sender and the new language as the ``language`` keyword.
    """

    @abstractmethod
    def _get_language(self, inst: StringStore) -> str:
        """Return the language currently selected on ``inst``."""

    @abstractmethod
    def _set_language(self, inst: StringStore, val: str) -> None:
        """Make ``val`` the language selected on ``inst``."""

    def __get__(self, instance: StringStore | None, owner=None):
        # Accessed on the class rather than an instance: hand back the
        # descriptor itself, so that its signal can be connected to.
        if instance is None:
            return self
        return self._get_language(instance)

    def __set__(self, inst: StringStore, val: str):
        self._set_language(inst, val)
        self.send(inst, language=val)
class LanguageDescriptor(AbstractLanguageDescriptor):
    """Language descriptor that keeps the engine's ``eternal`` record in sync."""

    def _get_language(self, inst: StringStore) -> str:
        """Return the store's ``_current_language``."""
        return inst._current_language

    def _set_language(self, inst: StringStore, lang: str) -> None:
        """Switch ``inst`` to ``lang``.

        The store's ``_switch_language`` is called only when the language
        actually changes. When the store has an engine that is not flagged
        as a ``_worker``, ``engine.eternal["language"]`` is updated as well.
        """
        if lang != inst._current_language:
            inst._switch_language(lang)
        # Stores built from a plain dict never get an ``engine`` attribute,
        # so there is nothing to synchronise in that case.
        engine = getattr(inst, "engine", None)
        if (
            engine is not None
            and not getattr(engine, "_worker", False)
            and engine.eternal["language"] != lang
        ):
            engine.eternal["language"] = lang
        inst._current_language = lang
class TamperEvidentDict[_K, _V](dict[_K, _V]):
tampered: bool
def __init__(self, data: list[tuple[_K, _V]] | dict[_K, _V] = ()):
self.tampered = False
super().__init__(data)
def __setitem__(self, key: _K, value: _V) -> None:
self.tampered = True
super().__setitem__(key, value)
def __delitem__(self, key: _K) -> None:
self.tampered = True
super().__delitem__(key)
class ChangeTrackingDict[_K, _V](UserDict[_K, _V]):
def __init__(self, data: list[tuple[_K, _V]] | dict[_K, _V] = ()):
self.changed = {}
super().__init__(data)
def apply_changes(self) -> None:
self.data.update(self.changed)
self.changed.clear()
def copy(self) -> dict[_K, _V]:
ret = {}
ret.update(self.data)
ret.update(self.changed)
return ret
def clear(self) -> None:
self.data.clear()
self.changed.clear()
def __contains__(self, item: _K) -> bool:
return item in self.changed or item in self.data
def __iter__(self) -> Iterator[_K]:
yield from self.changed
yield from self.data
def __len__(self) -> int:
return len(self.changed) + len(self.data)
def __getitem__(self, item: _K) -> _V:
if item in self.changed:
return self.changed[item]
return self.data[item]
def __setitem__(self, key: _K, value: _V) -> None:
self.changed[key] = value
def __delitem__(self, key: _K) -> None:
if key in self.changed:
del self.changed[key]
if key in self.data:
del self.data[key]
else:
del self.data[key]
class StringStore(MutableMapping[str, str], Signal):
    """Mapping of string IDs to text in the currently selected language.

    A store is backed either by a plain dictionary (nothing touches the
    disk) or by an engine plus a ``prefix`` directory holding one
    ``<language>.json`` file per language.

    Setting or deleting a string sends this signal with ``key`` and ``val``
    keywords; ``val`` is ``None`` on deletion.
    """

    language = LanguageDescriptor()
    _store = "strings"

    def __init__(
        self,
        engine_or_string_dict: AbstractEngine | dict,
        prefix: str | None,
        lang="eng",
    ):
        """Set up the store.

        :param engine_or_string_dict: an engine, or a dictionary of strings.
            A dictionary may be either ``{id: text}`` for ``lang`` or
            ``{language: {id: text}}``; it is treated as the latter when it
            has a ``lang`` key whose value is a dict.
        :param prefix: directory for the JSON files; ignored when a
            dictionary is given.
        :param lang: the language to start in.
        """
        super().__init__()
        if isinstance(engine_or_string_dict, dict):
            self._prefix = None
            self._current_language = lang
            if lang in engine_or_string_dict and isinstance(
                engine_or_string_dict[lang], dict
            ):
                self._languages = engine_or_string_dict
            else:
                self._languages = {lang: engine_or_string_dict}
        else:
            self.engine = engine_or_string_dict
            self._languages = {lang: TamperEvidentDict()}
            self._prefix = prefix
            self._current_language = lang
            self._switch_language(lang)

    def _lang_path(self, lang: str) -> str:
        """Return the path of the JSON file holding ``lang``'s strings."""
        return os.path.join(self._prefix, lang + ".json")

    def _read_language(self, lang: str) -> TamperEvidentDict:
        """Load ``lang`` from disk; raises ``FileNotFoundError`` if absent."""
        with open(self._lang_path(lang), "r", encoding="utf-8") as inf:
            return TamperEvidentDict(json.load(inf))

    def _write_language(self, lang: str) -> None:
        """Write ``lang`` to disk and mark it as no longer tampered with."""
        strings = self._languages[lang]
        # The directory may not exist yet the first time anything is saved.
        os.makedirs(self._prefix, exist_ok=True)
        with open(self._lang_path(lang), "w", encoding="utf-8") as outf:
            json.dump(strings, outf, indent=4, sort_keys=True)
        strings.tampered = False

    def _switch_language(self, lang: str) -> None:
        """Write the current language to disk, and load the new one if available"""
        if self._prefix is None:
            if lang not in self._languages:
                self._languages[lang] = TamperEvidentDict()
            return
        try:
            self._languages[lang] = self._read_language(lang)
        except FileNotFoundError:
            self._languages[lang] = TamperEvidentDict()
        assert self._current_language in self._languages
        if getattr(self._languages[self._current_language], "tampered", False):
            self._write_language(self._current_language)

    def __iter__(self) -> Iterator[str]:
        return iter(self._languages[self._current_language])

    def __len__(self) -> int:
        return len(self._languages[self._current_language])

    def __getitem__(self, k: str) -> str:
        return self._languages[self._current_language][k]

    def __setitem__(self, k: str, v: str) -> None:
        """Set the value of a string for the current language."""
        self._languages[self._current_language][k] = v
        self.send(self, key=k, val=v)

    def __delitem__(self, k: str) -> None:
        """Delete the string from the current language, and signal the
        deletion with ``val=None``.
        """
        del self._languages[self._current_language][k]
        self.send(self, key=k, val=None)

    def lang_items(self, lang: str | None = None) -> Iterator[tuple[str, str]]:
        """Yield pairs of (id, string) for the given language.

        With a ``prefix`` set, a language other than the current one is
        (re)read from its JSON file first. ``None`` means the current
        language.
        """
        if (
            self._prefix is not None
            and lang is not None
            and self._current_language != lang
        ):
            self._languages[lang] = self._read_language(lang)
        yield from self._languages[lang or self._current_language].items()

    def save(self, reimport: bool = False) -> None:
        """Write every modified language to its JSON file.

        Does nothing for dictionary-backed stores.

        :param reimport: when true, reload the current language from disk
            afterwards.
        """
        if self._prefix is None:
            return
        os.makedirs(self._prefix, exist_ok=True)
        for lang, strings in self._languages.items():
            if getattr(strings, "tampered", False):
                self._write_language(lang)
        if reimport:
            try:
                self._languages[self._current_language] = self._read_language(
                    self._current_language
                )
            except FileNotFoundError:
                # Nothing has ever been written for this language, so the
                # in-memory dictionary is already all there is.
                pass

    def blake2b(self) -> bytes:
        """Return the blake2b digest of the current language's strings.

        Keys are hashed in sorted order, the same order :meth:`save` writes
        them in, so the digest does not depend on insertion order and is
        unchanged by a save-and-reload round trip.
        """
        the_hash = blake2b()
        for k in sorted(self):
            the_hash.update(k.encode())
            the_hash.update(GROUP_SEP)
            the_hash.update(self[k].encode())
            the_hash.update(REC_SEP)
        return the_hash.digest()
@dataclass
class CodeHasher(ast.NodeVisitor):
# What if users want to subclass Place or whatever, and use type
# annotations to decide what rules to run on which of their subclasses?
# Well, that's too complicated.
_indent: int = 0
_updated: bool = False
_in_try_star: bool = False
@cached_property
def _blake2b(self):
return blake2b()
@cached_property
def pack(self):
from .facade import EngineFacade
return EngineFacade(None).pack
def update(self, data: bytes) -> None:
self._blake2b.update(data)
self._updated = True
def visit(self, node):
if self._indent > 0:
self.update(b"\t" * self._indent)
super().visit(node)
@contextmanager
def block(self, extra: bytes = b""):
self.update(b":")
if extra:
self.update(extra)
self._indent += 1
yield
self._indent -= 1
def traverse(self, node: ast.AST | list[ast.stmt]):
if isinstance(node, list):
for item in node:
self.traverse(item)
else:
super().visit(node)
def visit_Constant(self, node: ast.Constant) -> None:
v = node.value
if isinstance(v, tuple):
with self.delimit(b"(", b")"):
self.items_view(self._write_constant, v)
elif v is ...:
self.update(b"...")
else:
self._write_constant(v)
def maybe_semicolon(self):
if self._updated:
self.update(b";")
def maybe_newline(self):
if self._updated:
self.update(b"\n")
def fill(self, text: bytes = b"", allow_semicolon: bool = True):
if self._indent == 0 and allow_semicolon:
self.maybe_semicolon()
self.update(text)
else:
self.maybe_newline()
self.update(b"\t" * self._indent + text)
def visit_FunctionDef(self, node):
for decorator in node.decorator_list:
self.update(b"@")
self.hash_expr(decorator)
self.update(b"def ")
self.update(node.name.encode())
with self.delimit(b"(", b")"):
self.traverse(node.args)
# we don't care about annotations
with self.block():
self.traverse(node.body)
def visit_AsyncFunctionDef(self, node):
raise TypeError("Lisien isn't an event loop")
def visit_comprehension(self, node: ast.comprehension) -> None:
if node.is_async:
raise TypeError("Async not supported", node)
self.update(b" for ")
self.traverse(node.target)
for iffy in node.ifs:
self.update(b" if ")
self.traverse(iffy)
def visit_Name(self, node: ast.Name) -> None:
self.update(node.id.encode())
boolops = {"And": b"and", "Or": b"or"}
def visit_BoolOp(self, node):
with self.parens():
for c in node.values:
self.visit(c)
self.update(self.boolops[node.op.__class__.__name__])
def visit_NamedExpr(self, node):
with self.parens():
self.traverse(node.target)
self.update(b" := ")
self.traverse(node.value)
binop = {
"Add": b"+",
"Sub": b"-",
"Mult": b"*",
"MatMult": b"@",
"Div": b"/",
"Mod": b"%",
"LShift": b"<<",
"RShift": b">>",
"BitOr": b"|",
"BitXor": b"^",
"BitAnd": b"&",
"FloorDiv": b"//",
"Pow": b"**",
}
def visit_BinOp(self, node):
opstr = self.binop[node.op.__class__.__name__]
with self.parens():
self.traverse(node.left)
self.update(b" " + opstr + b" ")
self.traverse(node.right)
cmpops = {
"Eq": b"==",
"NotEq": b"!=",
"Lt": b"<",
"LtE": b"<=",
"Gt": b">",
"GtE": b"b>=",
"Is": b"is",
"IsNot": b"is not",
"In": b"in",
"NotIn": b"not in",
}
@contextmanager
def parens(self):
self.update(b"(")
yield
self.update(b")")
def visit_Compare(self, node):
with self.parens():
self.traverse(node.left)
for o, e in zip(node.ops, node.comparators):
self.update(b" " + self.cmpops[o.__class__.__name__] + b" ")
unop = {"Invert": b"~", "Not": b"not ", "UAdd": b"+", "USub": b"-"}
def visit_UnaryOp(self, node):
opstr = self.unop[node.op.__class__.__name__]
with self.parens():
self.update(opstr)
self.traverse(node.operand)
def visit_Lambda(self, node):
self.update(b"lambda ")
for arg in node.args.posonlyargs:
self.update(arg.arg.encode())
self.update(b",")
self.update(b": ")
self.traverse(node.body)
def visit_IfExp(self, node):
with self.parens():
self.traverse(node.body)
self.update(b" if ")
self.traverse(node.test)
self.update(b" else ")
self.traverse(node.orelse)
def visit_If(self, node):
self.fill(b"if ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
while (
node.orelse
and len(node.orelse) == 1
and isinstance(node.orelse[0], ast.If)
):
node = node.orelse[0]
self.fill(b"elif ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
if node.orelse:
self.fill(b"else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
@contextmanager
def delimit(self, left: bytes, right: bytes):
self.update(left)
yield
self.update(right)
def visit_Dict(self, node):
with self.delimit(b"{", b"}"):
for k, v in zip(node.keys, node.values):
if k is None:
self.update(b"**")
self.hash_expr(v)
else:
self.hash_expr(k)
self.update(b": ")
self.hash_expr(v)
self.update(b", ")
def interleave(self, inter, f, seq):
"""Call f on each item in seq, calling inter() in between."""
seq = iter(seq)
try:
f(next(seq))
except StopIteration:
pass
else:
for x in seq:
inter()
f(x)
def visit_Import(self, node):
self.fill(b"import ")
self.interleave(lambda: self.update(b", "), self.traverse, node.names)
def visit_ImportFrom(self, node):
self.fill(b"from ")
self.update(b"." * (node.level or 0))
if node.module:
self.update(node.module.encode())
def visit_Set(self, node):
if node.elts:
with self.delimit(b"{", b"}"):
self.interleave(
lambda: self.update(b", "), self.traverse, node.elts
)
else:
self.update(b"{*()}")
def visit_ListComp(self, node):
with self.delimit(b"[", b"]"):
self.traverse(node.elt)
for gen in node.generators:
self.traverse(gen)
def visit_GeneratorExp(self, node):
with self.delimit(b"(", b")"):
self.traverse(node.elt)
for gen in node.generators:
self.traverse(gen)
def visit_SetComp(self, node):
with self.delimit(b"{", b"}"):
self.traverse(node.elt)
for gen in node.generators:
self.traverse(gen)
def visit_DictComp(self, node):
with self.delimit(b"{", b"}"):
self.traverse(node.key)
self.update(b": ")
self.traverse(node.value)
for gen in node.generators:
self.traverse(gen)
def visit_Call(self, node):
self.traverse(node.func)
with self.delimit(b"(", b")"):
self.interleave(
lambda: self.update(b", "), self.traverse, node.args
)
self.interleave(
lambda: self.update(b", "), self.traverse, node.keywords
)
def visit_Subscript(self, node):
self.traverse(node.value)
with self.delimit(b"[", b"]"):
if isinstance(node.slice, ast.Tuple) and node.slice.elts:
self.items_view(self.traverse, node.slice.elts)
else:
self.traverse(node.slice)
def items_view(
self, traverser: Callable[[ast.AST], None], items: list[ast.AST]
):
if len(items) == 1:
traverser(items[0])
self.update(b",")
else:
self.interleave(lambda: self.update(b", "), traverser, items)
def visit_Starred(self, node):
self.update(b"*")
self.traverse(node.value)
def visit_Ellipsis(self, node):
self.update(b"...")
def visit_Slice(self, node):
if node.lower:
self.traverse(node.lower)
self.update(b":")
if node.upper:
self.traverse(node.upper)
if node.step:
self.update(b":")
self.traverse(node.step)
def visit_Match(self, node):
self.fill(b"match ", allow_semicolon=False)
self.traverse(node.subject)
with self.block():
for case in node.cases:
self.traverse(case)
def visit_arg(self, node):
self.update(node.arg.encode())
def visit_arguments(self, node):
first = True
all_args = node.posonlyargs + node.args
defaults = [None] * (
len(all_args) - len(node.defaults)
) + node.defaults
for i, (a, d) in enumerate(zip(all_args, defaults), 1):
if first:
first = False
else:
self.update(b", ")
self.traverse(a)
if d:
self.update(b"=")
self.traverse(d)
if i == len(node.posonlyargs):
self.update(b", /")
if node.vararg or node.kwonlyargs:
if first:
first = False
else:
self.update(b", ")
self.update(b"*")
if node.vararg:
self.update(node.vararg.arg.encode())
if node.kwonlyargs:
for a, d in zip(node.kwonlyargs, node.kw_defaults):
self.update(b", ")
self.traverse(a)
if d:
self.update(b"=")
self.traverse(d)
if node.kwarg:
if first:
first = False
else:
self.update(b", ")
self.update(b"**" + node.kwarg.arg.encode())
def visit_keyword(self, node):
if node.arg is None:
self.update(b"**")
else:
self.update(node.arg.encode())
self.update(b"=")
self.traverse(node.value)
def visit_alias(self, node):
self.update(node.name.encode())
if node.asname:
self.update(b" as " + node.asname.encode())
def visit_withitem(self, node):
self.traverse(node.context_expr)
if node.optional_vars:
self.update(b" as ")
self.traverse(node.optional_vars)
def visit_match_case(self, node):
self.fill(b"case ", allow_semicolon=False)
self.traverse(node.pattern)
if node.guard:
self.update(b" if ")
self.traverse(node.guard)
with self.block():
self.traverse(node.body)
def visit_MatchValue(self, node):
self.traverse(node.value)
def _write_constant(self, value):
self.update(repr(value).encode())
def visit_MatchSingleton(self, node):
self._write_constant(node.value)
def visit_MatchSequence(self, node):
with self.delimit(b"[", b"]"):
self.interleave(
lambda: self.update(b", "), self.traverse, node.patterns
)
def visit_MatchStar(self, node):
name = node.name
if name is None:
name = b"_"
self.update(b"*")
self.update(name.encode())
def visit_MatchMapping(self, node):
def upd_key_pattern_pair(pair):
k, p = pair
self.traverse(k)
self.update(b": ")
self.traverse(p)
with self.delimit(b"{", b"}"):
keys = node.keys
self.interleave(
lambda: self.update(b", "),
upd_key_pattern_pair,
zip(keys, node.patterns, strict=True),
)
rest = node.rest
if rest is not None:
if keys:
self.update(b", ")
self.update(b"**")
self.update(rest.encode())
def visit_MatchClass(self, node):
self.traverse(node.cls)
with self.delimit(b"(", b")"):
pats = node.patterns
self.interleave(lambda: self.update(b", "), self.traverse, pats)
attrs = node.kwd_attrs
if attrs:
def upd_attr_pat(pair):
attr, pat = pair
self.update(attr.encode())
self.update(b"=")
self.traverse(pat)
if pats:
self.update(b", ")
self.interleave(
lambda: self.update(b", "),
upd_attr_pat,
zip(attrs, node.kwd_patterns, strict=True),
)
def visit_MatchAs(self, node):
name = node.name
pattern = node.pattern
if name is None:
self.update(b"_")
elif pattern is None:
self.update(node.name.encode())
else:
with self.parens():
self.traverse(pattern)
self.update(b" as ")
self.update(node.name.encode())
def visit_MatchOr(self, node):
with self.parens():
self.interleave(
lambda: self.update(b" | "), self.traverse, node.patterns
)
def visit_FunctionType(self, node):
"""Since this is a function signature, rather than a real function, ignore it
The purpose of these hashes is to tell if you have the code needed to load
a game. Empty signatures don't matter.
"""
pass
def visit_Expr(self, node):
self.fill()
self.traverse(node.value)
def visit_Assign(self, node):
self.fill()
for target in node.targets:
self.traverse(target)
self.update(b" = ")
self.traverse(node.value)
def visit_AugAssign(self, node):
self.fill()
self.traverse(node.target)
self.update(b" ")
self.update(self.binop[node.op.__class__.__name__])
self.update(b"= ")
self.traverse(node.value)
def visit_AnnAssign(self, node):
if node.value is None: # rather than ast.Constant *representing* None
return
self.fill()
with self.delimit(b"(", b")"):
self.traverse(node.target)
self.update(b" = ")
self.traverse(node.value)
def visit_Return(self, node):
self.fill(b"return")
if node.value:
self.update(b" ")
self.traverse(node.value)
def visit_Pass(self, node):
return
def visit_Break(self, node):
self.fill(b"break")
def visit_Continue(self, node):
self.fill(b"continue")
def visit_Delete(self, node):
self.fill(b"del ")
self.interleave(
lambda: self.update(b", "), self.traverse, node.targets
)
def visit_Assert(self, node):
return
def visit_Global(self, node):
self.fill(b"global ")
self.interleave(lambda: self.update(b", "), self.write, node.names)
def visit_Nonlocal(self, node):
self.fill(b"nonlocal ")
self.interleave(lambda: self.update(b", "), self.write, node.names)
def write(self, s: str) -> None:
self.update(s.encode())
def visit_Await(self, node):
raise TypeError("Lisien isn't an event loop")
def visit_Yield(self, node):
with self.parens():
self.update(b"yield")
if node.value:
self.update(b" ")
self.traverse(node.value)
def visit_YieldFrom(self, node):
with self.parens():
self.update(b"yield from ")
self.traverse(node.value)
def visit_Raise(self, node):
self.fill(b"raise")
if not node.exc:
return
self.update(b" ")
self.traverse(node.exc)
if node.cause:
self.update(b" from ")
self.traverse(node.cause)
def _visit_try(self, node):
self.fill(b"try", allow_semicolon=False)
with self.block():
self.traverse(node.body)
for ex in node.handlers:
self.traverse(ex)
if node.orelse:
self.fill(b"else", allow_semicolon=False)
if node.finalbody:
self.fill(b"finally", allow_semicolon=False)
with self.block():
self.traverse(node.finalbody)
def visit_Try(self, node):
prev_in_try_star = self._in_try_star
try:
self._in_try_star = False
self._visit_try(node)
finally:
self._in_try_star = prev_in_try_star
def visit_TryStar(self, node):
prev_in_try_star = self._in_try_star
try:
self._in_try_star = True
self._visit_try(node)
finally:
self._in_try_star = prev_in_try_star
def visit_ExceptHandler(self, node):
self.fill(
b"except*" if self._in_try_star else b"except",
allow_semicolon=False,
)
if node.type:
self.update(b" ")
self.traverse(node.type)
if node.name:
self.update(b" as ")
self.write(node.name)
with self.block():
self.traverse(node.body)
def visit_ClassDef(self, node):
self.maybe_newline()
for deco in node.decorator_list:
self.fill(b"@", allow_semicolon=False)
self.traverse(deco)
self.fill(b"class " + node.name.encode(), allow_semicolon=False)
# we don't care about type parameters
with self.delimit_if(
b"(", b")", condition=node.bases or node.keywords
):
self.interleave(
lambda: self.update(b", "), self.traverse, node.bases
)
self.interleave(
lambda: self.update(b", "), self.traverse, node.keywords
)
with self.block():
# we don't care about docstrings
self.traverse(node.body)
def visit_For(self, node):
self.fill(b"for ", allow_semicolon=False)
self.traverse(node.target)
self.update(b" in ")
self.traverse(node.iter)
with self.block(): # we don't care about type comments
self.traverse(node.body)
if node.orelse:
self.fill(b"else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
def visit_AsyncFor(self, node):
raise TypeError("Lisien isn't an event loop")
def visit_While(self, node):
self.fill(b"while ", allow_semicolon=False)
self.traverse(node.test)
with self.block():
self.traverse(node.body)
if node.orelse:
self.fill(b"else", allow_semicolon=False)
with self.block():
self.traverse(node.orelse)
def visit_With(self, node):
self.fill(b"with ", allow_semicolon=False)
self.interleave(lambda: self.update(b", "), self.traverse, node.items)
with self.block(): # we don't care about type comments
self.traverse(node.body)
def visit_AsyncWith(self, node):
raise TypeError("Lisien isn't an event loop")
def _write_ftstring_inner(self, node, is_format_spec=False):
if isinstance(node, ast.JoinedStr):
for value in node.values:
self._write_ftstring_inner(
value, is_format_spec=is_format_spec
)
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
value = node.value.replace("{", "{{").replace("}", "}}")
if is_format_spec:
value = value.replace("\\", "\\\\")
value = value.replace("'", "\\'")
value = value.replace('"', '\\"')
value = value.replace("\n", "\\n")
self.write(value)
elif isinstance(node, ast.FormattedValue):
self.visit_FormattedValue(node)
elif isinstance(node, ast.Interpolation):
self.visit_Interpolation(node)
else:
raise ValueError("Unexpected node inside JoinedStr", node)
def _write_ftstring(self, values, prefix):
self.write(prefix)
for value in values:
self._write_ftstring_inner(value)
def visit_JoinedStr(self, node):
self._write_ftstring(node.values, "f")
def visit_TemplateStr(self, node):
self._write_ftstring(node.values, "t")
def _write_interpolation(self, node, is_interpolation=False):
with self.delimit(b"{", b"}"):
if is_interpolation:
expr = node.str
else:
self.visit(node.value)
return
self.write(expr)
if node.conversion != -1:
self.update(b"!")
self.update(node.conversion.to_bytes())
if node.format_spec:
self.update(b":")
self._write_fstring_inner(
node.format_spec, is_format_spec=True
)
def visit_FormattedValue(self, node):
self._write_interpolation(node)
def visit_Interpolation(self, node):
self._write_interpolation(node, is_interpolation=True)
def visit_List(self, node):
with self.delimit(b"[", b"]"):
self.interleave(
lambda: self.update(b", "), self.traverse, node.elts
)
def visit_TypeVar(self, node):
return
def visit_TypeVarTuple(self, node):
return
def visit_ParamSpec(self, node):
return
def visit_TypeAlias(self, node):
return
def visit_Tuple(self, node):
with self.delimit(b"(", b")"):
self.items_view(self.traverse, node.elts)
def visit_Attribute(self, node):
self.traverse(node.value)
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, int
):
self.update(b" ")
self.update(b".")
self.write(node.attr)
def digest(self) -> bytes:
return self._blake2b.digest()
def hexdigest(self) -> str:
return self._blake2b.hexdigest()
def b64digest(self) -> str:
digest = self.digest()
encoded = base64.standard_b64encode(digest)
return encoded.decode()
[docs]
class FunctionStore[_K: str, _T: FunctionType | MethodType](
AbstractFunctionStore[_K, _T], Signal
):
"""A module-like object that lets you alter its code and save your changes.
Instantiate it with a path to a file that you want to keep the code in.
Assign functions to its attributes, then call its ``save()`` method,
and they'll be unparsed and written to the file.
This is a ``Signal``, so you can pass a function to its ``connect`` method,
and it will be called when a function is added, changed, or deleted.
The keyword arguments will be ``attr``, the name of the function, and ``val``,
the function itself.
"""
_filename: str | None
def __init__(
self, filename: str | None, initial: dict = None, module: str = None
):
if initial is None:
initial = {}
super().__init__()
if filename is None:
self._filename = None
self._module = self.__name__ = module
self._ast = ast.Module(body=[], type_ignores=[])
self._ast_idx = {}
self._need_save = False
self._locl = initial
else:
if not filename.endswith(".py"):
raise ValueError(
"FunctionStore can only work with pure Python source code"
)
self._filename = os.path.abspath(os.path.realpath(filename))
self._store = os.path.basename(self._filename).removesuffix(".py")
try:
self.reimport()
except (FileNotFoundError, ModuleNotFoundError):
self._module = module
self._ast = ast.Module(body=[], type_ignores=[])
self._ast_idx = {}
self.save()
self._need_save = False
self._locl = {}
for k, v in initial.items():
setattr(self, k, v)
def __dir__(self):
yield from self._locl
yield from super().__dir__()
def __getattr__(self, k):
if k in self._locl:
return self._locl[k]
elif self._need_save:
self.save()
return getattr(self._module, k)
elif hasattr(self._module, k):
return getattr(self._module, k)
else:
raise AttributeError("No attribute ", k)
def __setattr__(self, k, v):
if not callable(v):
super().__setattr__(k, v)
return
self._set_source(k, getsource(v), func=v)
def _set_source(self, k: str, source: str, func: Callable | None = None):
if func is None:
holder = {}
exec(source, holder)
if k not in holder:
raise NameError(
"Function in source has a different name", k, source
)
func = holder[k]
outdented = dedent_source(source)
expr = ast.parse(outdented)
expr.body[0].name = k
if k in self._ast_idx:
self._ast.body[self._ast_idx[k]] = expr
else:
self._ast_idx[k] = len(self._ast.body)
self._ast.body.append(expr)
if self._filename is not None:
self._need_save = True
if isinstance(self._module, str):
func.__module__ = self._module
self._locl[k] = func
self.send(self, attr=k, val=func)
def __call__(self, v):
if isinstance(self._module, str):
v.__module__ = self._module
elif hasattr(self._module, "__name__"):
v.__module__ = self._module.__name__
setattr(self, v.__name__, v)
return v
def __delattr__(self, k):
del self._locl[k]
del self._ast.body[self._ast_idx[k]]
del self._ast_idx[k]
for name in list(self._ast_idx):
if name > k:
self._ast_idx[name] -= 1
if self._filename is not None:
self._need_save = True
self.send(self, attr=k, val=None)
def save(self, reimport=True):
if self._filename is None:
return
with open(self._filename, "w", encoding="utf-8") as outf:
outf.write(ast.unparse(self._ast))
self._need_save = False
if reimport:
self.reimport()
def reimport(self, signal: bool = True):
if self._filename is None:
return
path, filename = os.path.split(self._filename)
modname = filename[:-3]
if modname in sys.modules:
del sys.modules[modname]
modname = filename[:-3]
spec = importlib.util.spec_from_file_location(modname, self._filename)
self._module = importlib.util.module_from_spec(spec)
sys.modules[modname] = self._module
spec.loader.exec_module(self._module)
self._ast = ast.parse(self._module.__loader__.get_data(self._filename))
self._ast_idx = {}
for i, node in enumerate(self._ast.body):
if hasattr(node, "name"):
self._ast_idx[node.name] = i
elif hasattr(node, "__name__"):
self._ast_idx[node.__name__] = i
if signal:
self.send(self, attr=None, val=None)
def iterplain(self):
for name, idx in self._ast_idx.items():
yield name, ast.unparse(self._ast.body[idx])
def store_source(self, v: str, name: str | None = None) -> None:
self._need_save = True
outdented = dedent_source(v)
mod = ast.parse(outdented)
expr = ast.Expr(mod)
if len(expr.value.body) != 1:
raise ValueError("Tried to store more than one function")
if name is None:
name = expr.value.body[0].name
else:
expr.value.body[0].name = name
if name in self._ast_idx:
self._ast.body[self._ast_idx[name]] = expr
else:
self._ast_idx[name] = len(self._ast.body)
self._ast.body.append(expr)
locl = {}
exec(compile(mod, self._filename or "", "exec"), {}, locl)
self._locl.update(locl)
self.send(self, attr=name, val=locl[name])
def get_source(self, name: str) -> str:
return ast.unparse(self._ast.body[self._ast_idx[name]])
def blake2b(self) -> str:
"""Return the blake2b hash digest of the code stored here
Neither formatting nor type annotations are considered significant for
the purposes of this hash.
The hash is returned in a base64 string.
"""
hasher = CodeHasher()
todo = dict(self._ast_idx)
# stripped_ast = deepcopy(self._ast.body)
# astor.strip_tree(stripped_ast)
for k in sort_set(todo.keys()):
funcdef = self._ast.body[todo[k]]
if not isinstance(funcdef, ast.FunctionDef):
raise TypeError(
"Only store function defs in FunctionStore", funcdef
)
hasher.visit(funcdef)
return hasher.b64digest()
def __getstate__(self):
return self._locl, self._ast, self._ast_idx
def __setstate__(self, state):
self._locl, self._ast, self._ast_idx = state
class TriggerStore(FunctionStore[TriggerFuncName, TriggerFunc]):
    """Function store for triggers, with a built-in ``truth`` trigger."""

    def get_source(self, name: str) -> str:
        """Return the source of ``name``; ``truth`` has a fixed source."""
        builtin_truth = "def truth(*args):\n\treturn True"
        return builtin_truth if name == "truth" else super().get_source(name)

    @staticmethod
    def truth(*args):
        """Return ``True`` no matter what arguments are given."""
        return True
class PrereqStore(FunctionStore[PrereqFuncName, PrereqFunc]):
    """Function store specialised to prerequisite functions."""
class ActionStore(FunctionStore[ActionFuncName, ActionFunc]):
    """Function store specialised to action functions."""
class UniversalMapping(MutableMapping, Signal):
    """Mapping for variables that are global but which I keep history for"""

    __slots__ = ["engine"]

    def __init__(self, engine):
        """Set up the signal, then remember the engine."""
        super().__init__()
        self.engine = engine

    def __iter__(self):
        cache = self.engine._universal_cache
        return cache.iter_keys(*self.engine.time)

    def __len__(self):
        cache = self.engine._universal_cache
        return cache.count_keys(*self.engine.time)

    def __getitem__(self, k: KeyHint | UniversalKey):
        """Get the current value of this key"""
        current = self._get_cache_now(k)
        return wrapval(self, k, current)

    def _get_cache_now(self, k: UniversalKey):
        """Retrieve ``k`` from the universal cache at the engine's time."""
        cache = self.engine._universal_cache
        return cache.retrieve(k, *self.engine.time)

    def __setitem__(self, k: KeyHint | UniversalKey, v: ValueHint | Value):
        """Set k=v at the current branch and tick"""
        try:
            unchanged = bool(v == self._get_cache_now(k))
        except KeyError:
            # Nothing stored for this key yet, so this is a change.
            unchanged = False
        if unchanged:
            return
        engine = self.engine
        branch, turn, tick = engine._nbtt()
        engine._universal_cache.store(k, branch, turn, tick, v)
        engine.db.universal_set(k, branch, turn, tick, v)
        self.send(self, key=k, val=v)

    def _set_cache_now(self, k: UniversalKey, v: Value):
        """Store ``v`` for ``k`` in the cache at the engine's time."""
        cache = self.engine._universal_cache
        cache.store(k, *self.engine.time, v)

    def __delitem__(self, k: KeyHint | UniversalKey):
        """Unset this key for the present (branch, tick)"""
        engine = self.engine
        branch, turn, tick = engine._nbtt()
        # ``...`` is what gets stored, and signalled, to mark a deletion.
        engine._universal_cache.store(k, branch, turn, tick, ...)
        engine.db.universal_del(k, branch, turn, tick)
        self.send(self, key=k, val=...)
class CharacterMapping(MutableMapping, Signal):
    """A mapping by which to access :class:`Character` objects.

    If a character already exists, you can always get its name here to
    get the :class:`Character` object. Deleting an item here will
    delete the character from the world, even if there are still
    :class:`Character` objects referring to it; those won't do
    anything useful anymore.
    """

    engine = getatt("orm")

    def __init__(self, orm):
        self.orm = orm
        Signal.__init__(self)

    def __iter__(self):
        branch, turn, tick = self.engine.time
        graphs = self.engine._graph_cache
        return graphs.iter_keys(branch, turn, tick)

    def __len__(self):
        branch, turn, tick = self.engine.time
        graphs = self.engine._graph_cache
        return graphs.count_keys(branch, turn, tick)

    def __contains__(self, item: KeyHint | CharName) -> bool:
        """Tell whether a ``"DiGraph"`` is recorded under ``item`` right now."""
        branch, turn, tick = self.engine.time
        try:
            kind = self.engine._graph_cache.retrieve(item, branch, turn, tick)
        except KeyError:
            return False
        return kind == "DiGraph"

    def __getitem__(self, name: KeyHint | CharName) -> Character:
        """Return the named character, if it's been created.

        Try to use the cache if possible.

        Raises ``KeyError`` if there is no such character, and ``TypeError``
        if the cached object is not a :class:`Character`.
        """
        from .character import Character

        name = CharName(name)
        if name not in self:
            raise KeyError("No such character", name)
        objs = self.engine._graph_objs
        if name not in objs:
            objs[name] = Character(
                self.engine, name, init_rulebooks=name not in self
            )
        character = objs[name]
        if isinstance(character, Character):
            return character
        raise TypeError(
            "You put something weird in the Character cache", type(character)
        )

    def __setitem__(
        self,
        name: KeyHint | CharName,
        value: dict[KeyHint | Stat, ValueHint | Value] | nx.Graph,
    ):
        """Make a new character by the given name, and initialize its data to
        the given value.
        """
        engine = self.engine
        engine._init_graph(name, "DiGraph", value)
        self.send(self, key=name, val=engine.character[name])

    def __delitem__(self, name: KeyHint | CharName):
        """Delete the character from the engine, then signal ``val=None``."""
        self.engine.del_character(name)
        self.send(self, key=name, val=None)
# Module-level type variables. The generic classes in this file declare
# their own ``_K``/``_V`` type parameters in their class headers, so these
# are separate objects from the ones used inside those classes.
_K = TypeVar("_K")
_V = TypeVar("_V")
class CompositeDict[_K, _V](MutableMapping[_K, _V], Signal):
"""Combine two dictionaries into one"""
def __init__(self, d1, d2):
"""Store dictionaries"""
super().__init__()
self.d1 = d1
self.d2 = d2
def __iter__(self):
"""Iterate over both dictionaries' keys"""
for k in self.d1:
yield k
for k in self.d2:
yield k
def __len__(self):
"""Sum the lengths of both dictionaries"""
return len(self.d1) + len(self.d2)
def __contains__(self, item):
return item in self.d1 or item in self.d2
def __getitem__(self, k):
"""Get an item from ``d1`` if possible, then ``d2``"""
try:
return self.d1[k]
except KeyError:
return self.d2[k]
def __setitem__(self, key, value):
self.d1[key] = value
self.send(self, key=key, value=value)
def __delitem__(self, key):
deleted = False
if key in self.d2:
deleted = True
del self.d2[key]
if key in self.d1:
deleted = True
del self.d1[key]
if not deleted:
raise KeyError("{} is in neither of my wrapped dicts".format(key))
self.send(self, key=key, value=None)
def patch(self, d):
"""Recursive update"""
for k, v in d.items():
if k in self:
self[k].update(v)
else:
self[k] = deepcopy(v)