api/backend/python/xl-client/dial_xl/overrides.py (302 lines of code) (raw):
from typing import Iterable
from dial_xl.events import (
Event,
ObservableNode,
ObservableObserver,
notify_observer,
)
from dial_xl.reader import _Reader
from dial_xl.utils import _escape_field_name, _unescape_field_name
class _OverrideHeader:
    """A single column header of an ``override`` section.

    Besides the header text itself, the object keeps the exact text that
    surrounds it -- ``before`` and ``after`` -- so that ``to_dsl`` can
    reproduce the source it was parsed from.
    """

    __before: str = ""
    __name: str
    __after: str = "\n"

    def __init__(self, name: str):
        self.__name = name

    @property
    def before(self) -> str:
        """Text preceding the header text (read-only)."""
        return self.__before

    @property
    def name(self) -> str:
        """The header text as written in the DSL."""
        return self.__name

    @name.setter
    def name(self, value: str):
        self.__name = value

    @property
    def after(self) -> str:
        """Text following the header, e.g. ``","`` or ``"\\n"``."""
        return self.__after

    @after.setter
    def after(self, value: str):
        self.__after = value

    def to_dsl(self) -> str:
        """Converts the override to DSL format."""
        before, name, after = self.__before, self.__name, self.__after
        return f"{before}{name}{after}"

    @classmethod
    def _deserialize(cls, reader: _Reader) -> "_OverrideHeader":
        # Pull the three pieces in source order: text up to the entity's span
        # "from" offset, the span text up to its "to" offset, then whatever
        # reader.till_linebreak() yields.
        leading = reader.next(lambda d: d["span"]["from"])
        text = reader.next(lambda d: d["span"]["to"])
        trailing = reader.till_linebreak()
        header = cls(text)
        header.__before = leading
        header.__after = trailing
        return header
class _Override:
    """A single cell of an ``override`` section row, with its surrounding text.

    ``to_dsl`` emits the text before the cell, the cell value and the text
    after it (a ``","`` separator by default only for non-last cells), so a
    parsed row round-trips unchanged.
    """

    __before: str = ""
    # Raw cell text; an empty string denotes an empty cell. (Previously
    # declared as ``__override``, which was never assigned or read.)
    __value: str
    __after: str = "\n"

    def __init__(self, value: str):
        self.__value = value

    @property
    def override(self) -> str:
        """The raw cell text."""
        return self.__value

    @override.setter
    def override(self, value: str):
        self.__value = value

    @property
    def after(self) -> str:
        """Text following the cell, e.g. ``","`` or ``"\\n"``."""
        return self.__after

    @after.setter
    def after(self, value: str):
        self.__after = value

    def to_dsl(self) -> str:
        """Converts the override to DSL format."""
        return f"{self.__before}{self.__value}{self.__after}"

    @classmethod
    def _deserialize(cls, reader: _Reader) -> "_Override":
        result = cls("")
        # Source order: text up to the span's "from" offset, the span text up
        # to its "to" offset, then whatever reader.till_linebreak() yields.
        result.__before = reader.next(lambda d: d["span"]["from"])
        result.__value = reader.next(lambda d: d["span"]["to"])
        result.__after = reader.till_linebreak()
        return result
class _OverrideLine:
    """One data row of an ``override`` section: an ordered list of cells."""

    __overrides: list[_Override]

    def __init__(self, overrides: list[_Override]):
        self.__overrides = overrides

    @property
    def overrides(self) -> list[_Override]:
        """The row's cells in column order (the live list, not a copy)."""
        return self.__overrides

    def to_dsl(self) -> str:
        """Converts the override to DSL format."""
        return "".join(override.to_dsl() for override in self.__overrides)

    @classmethod
    def _deserialize(cls, reader: _Reader) -> "_OverrideLine":
        result = cls([])
        # NOTE(review): the line entity is iterated directly, so it is
        # presumably a list of per-cell entities -- confirm against the parser.
        for value_entity in reader.entity:
            value_reader = reader.with_entity(value_entity)
            result.__overrides.append(_Override._deserialize(value_reader))
            # Advance the outer reader past the text consumed by this cell.
            reader.position = value_reader.position
        return result
class _Overrides:
    """Lossless syntax tree of an ``override`` section: keyword, headers, rows."""

    __before: str = ""
    __prefix: str = "override\n"
    __headers: list[_OverrideHeader]
    __lines: list[_OverrideLine]

    def __init__(self):
        self.__headers = []
        self.__lines = []

    def to_dsl(self) -> str:
        """Converts the override to DSL format."""
        return (
            f"{self.__before}"
            f"{self.__prefix}"
            f"{''.join(header.to_dsl() for header in self.__headers)}"
            f"{''.join(line.to_dsl() for line in self.__lines)}"
        )

    @property
    def headers(self) -> list[_OverrideHeader]:
        """Column headers in order (the live list, not a copy)."""
        return self.__headers

    @property
    def lines(self) -> list[_OverrideLine]:
        """Data rows in order (the live list, not a copy)."""
        return self.__lines

    @classmethod
    def _deserialize(cls, reader: _Reader) -> "_Overrides":
        result = cls()
        result.__before = reader.next(lambda d: d["span"]["from"])
        result.__prefix = reader.before_next()
        # Child readers share the source text; after each child, move the
        # outer reader to where the child stopped.
        for header_entity in reader.entity.get("headers", []):
            header_reader = reader.with_entity(header_entity)
            result.__headers.append(_OverrideHeader._deserialize(header_reader))
            reader.position = header_reader.position
        for line_entity in reader.entity.get("values", []):
            line_reader = reader.with_entity(line_entity)
            result.__lines.append(_OverrideLine._deserialize(line_reader))
            reader.position = line_reader.position
        return result
class Override(ObservableNode):
    """One row of manual overrides: field values by name plus an optional row number.

    ``__setitem__``, ``__delitem__`` and the ``row_number`` setter are wrapped
    with ``notify_observer`` so an attached observer (such as ``Overrides``)
    is told about each change.
    """

    __values: dict[str, str]
    __row_number: str | None = None

    def __init__(
        self,
        values: dict[str, str] | None = None,
        row_number: str | None = None,
    ):
        # Keep a private copy: if the caller's dict were stored directly,
        # later edits to it would change this row without going through
        # notify_observer, so an attached observer would never see them.
        self.__values = dict(values) if values else {}
        self.__row_number = row_number

    @property
    def names(self) -> Iterable[str]:
        """Names of the fields that have a value in this row (a live view)."""
        return self.__values.keys()

    @property
    def row_number(self) -> str | None:
        """The row number this override applies to, or ``None``."""
        return self.__row_number

    @row_number.setter
    @notify_observer
    def row_number(self, value: str | None):
        """Sets the row number of the override and invalidates compilation/computation results and sheet parsing errors"""
        self.__row_number = value

    def __getitem__(self, key: str) -> str | None:
        """Return the value of field *key*, or ``None`` if this row has none."""
        return self.__values.get(key)

    @notify_observer
    def __setitem__(self, key: str, value: str):
        self.__values[key] = value

    @notify_observer
    def __delitem__(self, key: str):
        del self.__values[key]
class Overrides(ObservableObserver):
    """Manual overrides of a table, kept in sync with their DSL text.

    Holds the user-facing ``Override`` rows alongside the lossless
    ``_Overrides`` syntax tree. Changes made through this collection, or
    through any attached ``Override``, are mirrored into the syntax tree so
    ``to_dsl`` reflects them.
    """

    __overrides: _Overrides
    # Column index of the special "row" column, or None when there is none.
    __row_position: int | None = None
    # Header names in column order: escaped field names, plus the bare "row"
    # keyword for the row-number column.
    __field_names: list[str]
    __lines: list[Override]

    def __init__(self):
        self.__overrides = _Overrides()
        self.__field_names = []
        self.__lines = []

    @property
    def field_names(self) -> Iterable[str]:
        """Unescaped names of the overridden fields (the ``row`` column is excluded)."""
        return (
            _unescape_field_name(name)
            for name in self.__field_names
            if name != "row"
        )

    @property
    def row_position(self) -> int | None:
        """Column index of the ``row`` column, or ``None`` if there is none."""
        return self.__row_position

    def __len__(self):
        return len(self.__lines)

    def __getitem__(self, index: int) -> Override:
        return self.__lines[index]

    @notify_observer
    def __setitem__(self, index: int, value: Override):
        """Replace the override line at *index*, detaching the previous one.

        Raises:
            IndexError: if *index* is out of range.
        """
        if index >= len(self.__lines):
            raise IndexError("Override line index out of range")
        old = self.__lines[index]
        old._detach()
        self.__overrides.lines[index] = self._attach_override_line(value)
        self.__lines[index] = value

    @notify_observer
    def __delitem__(self, index: int):
        """Remove the override line at *index* and detach it."""
        self.__overrides.lines.pop(index)
        override = self.__lines.pop(index)
        override._detach()

    @notify_observer
    def append(self, value: Override):
        """Append *value* as a new line, adding any columns it needs."""
        self.__overrides.lines.append(self._attach_override_line(value))
        self.__lines.append(value)

    def _attach_override_line(self, value: Override) -> _OverrideLine:
        """Build the DSL row for *value*, adding missing columns, and observe it."""
        if value.row_number is not None:
            self._add_empty_if_missing("row")
        for name in value.names:
            self._add_empty_if_missing(_escape_field_name(name))

        values: list[_Override] = []
        for name in self.__field_names:
            if values:
                # Every cell but the last is followed by a comma; the last
                # keeps its default line ending.
                values[-1].after = ","
            if name == "row":
                # The bare "row" header is the row-number keyword, never a
                # field. Check it first so it cannot pick up a field value,
                # matching how _set_overrides and field_names treat it.
                values.append(_Override(value.row_number or ""))
            else:
                field_name = _unescape_field_name(name)
                values.append(
                    _Override(value[field_name] if field_name in value.names else "")
                )

        value._attach(self)
        return _OverrideLine(values)

    def _add_empty_if_missing(self, name: str):
        """Append column *name* (an escaped field name or ``"row"``) unless present.

        Every existing line gets an empty cell in the new column, and the
        previous last header/cell get a ``","`` separator.
        """
        if name in self.__field_names:
            return
        if name == "row":
            self.__row_position = len(self.__field_names)
        self.__field_names.append(name)
        if self.__overrides.headers:
            self.__overrides.headers[-1].after = ","
        self.__overrides.headers.append(_OverrideHeader(name))
        for line in self.__overrides.lines:
            if line.overrides:
                line.overrides[-1].after = ","
            line.overrides.append(_Override(""))

    def to_dsl(self) -> str:
        """Converts the manual overrides to DSL format."""
        return self.__overrides.to_dsl()

    def _notify_before(self, event: Event):
        """Forward *event* upward, then mirror an attached ``Override`` change into the tree."""
        if self._observer:
            self._observer._notify_before(event)
        sender = event.sender
        if isinstance(sender, Override):
            if event.method_name == "__setitem__":
                self._on_override_update(
                    self.__lines.index(sender),
                    _escape_field_name(event.kwargs["key"]),
                    event.kwargs["value"],
                )
            elif event.method_name == "__delitem__":
                self._on_override_remove(
                    self.__lines.index(sender),
                    _escape_field_name(event.kwargs["key"]),
                )
            elif event.method_name == "row_number":
                if event.kwargs["value"] is None:
                    self._on_override_remove(self.__lines.index(sender), "row")
                else:
                    self._on_override_update(
                        self.__lines.index(sender), "row", event.kwargs["value"]
                    )

    def _set_overrides(self, overrides: _Overrides):
        """Adopt a parsed ``_Overrides`` tree and rebuild the ``Override`` rows.

        Cells are keyed by unescaped header name, except the ``row`` column,
        whose cell becomes the row number. The new rows are attached here.
        """
        self.__field_names = [header.name for header in overrides.headers]
        self.__row_position = next(
            (
                index
                for index, header in enumerate(overrides.headers)
                if header.name == "row"
            ),
            None,
        )
        self.__lines = [
            Override(
                values={
                    _unescape_field_name(
                        self.__field_names[index]
                    ): override.override
                    for index, override in enumerate(line.overrides)
                    if self.__row_position is None
                    or index != self.__row_position
                },
                row_number=(
                    line.overrides[self.__row_position].override
                    if self.__row_position is not None
                    else None
                ),
            )
            for line in overrides.lines
        ]
        for line in self.__lines:
            line._attach(self)
        self.__overrides = overrides

    def _on_override_update(self, index: int, name: str, value: str):
        """Set cell *name* of line *index* to *value*, creating the column if needed."""
        self._add_empty_if_missing(name)
        position = self.__field_names.index(name)
        self.__overrides.lines[index].overrides[position].override = value

    def _on_override_remove(self, index: int, name: str):
        """Clear cell *name* of line *index*; drop the column once it is empty everywhere."""
        if name not in self.__field_names:
            # No such column (e.g. row_number set to None when no row column
            # exists): nothing to clear in the DSL.
            return
        position = self.__field_names.index(name)
        self.__overrides.lines[index].overrides[position].override = ""
        if any(
            line.overrides[position].override != ""
            for line in self.__overrides.lines
        ):
            return

        removed_header = self.__overrides.headers.pop(position)
        if position > 0:
            # The preceding header inherits the removed one's trailing text
            # (e.g. the line ending when the last column is dropped).
            self.__overrides.headers[position - 1].after = removed_header.after
        for line in self.__overrides.lines:
            removed_value = line.overrides.pop(position)
            if position > 0:
                line.overrides[position - 1].after = removed_value.after
        self.__field_names.pop(position)
        # Columns after the removed one shifted left (or the row column itself
        # was removed); keep row_position consistent with __field_names.
        self.__row_position = (
            self.__field_names.index("row") if "row" in self.__field_names else None
        )

    @classmethod
    def _deserialize(cls, reader: _Reader) -> "Overrides":
        result = cls()
        result._set_overrides(_Overrides._deserialize(reader))
        return result