diff --git a/.github/workflows/typing-check.yml b/.github/workflows/typing-check.yml index 1ac277103..4b53634b4 100644 --- a/.github/workflows/typing-check.yml +++ b/.github/workflows/typing-check.yml @@ -16,7 +16,7 @@ jobs: max-parallel: 3 matrix: # add packages to check typing - package-name: ["geos-geomechanics", "geos-posp", "geos-timehistory", "geos-utils", "geos-xml-tools", "hdf5-wrapper"] + package-name: ["geos-geomechanics", "geos-posp", "geos-timehistory", "geos-utils", "geos-trame", "geos-xml-tools", "hdf5-wrapper"] steps: - uses: actions/checkout@v4 @@ -30,7 +30,7 @@ jobs: # working-directory: ./${{ matrix.package-name }} run: | python -m pip install --upgrade pip - python -m pip install mypy ruff + python -m pip install mypy ruff types-PyYAML - name: Typing check with mypy # working-directory: ./${{ matrix.package-name }} diff --git a/geos-trame/.pre-commit-config.yaml b/geos-trame/.pre-commit-config.yaml index dc0f93736..406666fee 100644 --- a/geos-trame/.pre-commit-config.yaml +++ b/geos-trame/.pre-commit-config.yaml @@ -1,17 +1,23 @@ repos: - - repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - exclude: ^.*\b(schema_generated)\b.*$ - entry: black --check --force-exclude - - repo: https://github.com/codespell-project/codespell rev: v2.1.0 hooks: - id: codespell - - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.12 + hooks: + - id: ruff + args: ["--config", "./.ruff.toml"] + + - repo: https://github.com/google/yapf + rev: v0.43.0 + hooks: + - id: yapf + args: ["-ir", "--style", "./.style.yapf"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.16.0 hooks: - - id: flake8 + - id: mypy + additional_dependencies: [types-PyYAML] diff --git a/geos-trame/CONTRIBUTING.rst b/geos-trame/CONTRIBUTING.rst deleted file mode 100644 index 6d5a9c236..000000000 --- a/geos-trame/CONTRIBUTING.rst +++ /dev/null @@ -1,19 +0,0 @@ -========================== -Contributing to geos-trame -========================== - -#. Clone the repository using ``git clone`` -#. Install pre-commit via ``pip install pre-commit`` -#. Run ``pre-commit install`` to set up pre-commit hooks -#. Make changes to the code, and commit your changes to a separate branch -#. Create a fork of the repository on GitHub -#. Push your branch to your fork, and open a pull request - -Tips -#### - -#. When first creating a new project, it is helpful to run ``pre-commit run --all-files`` to ensure all files pass the pre-commit checks. -#. A quick way to fix ``black`` issues is by installing black (``pip install black``) and running the ``black`` command at the root of your repository. -#. Sometimes, ``black`` and ``flake8`` do not agree. Add options to your ``.flake8`` file to fix these things. See the `flake8 configuration docs `_ for more details. -#. A quick way to fix ``codespell`` issues is by installing codespell (``pip install codespell``) and running the ``codespell -w`` command at the root of your directory. -#. The `.codespellrc file `_ can be used fix any other codespell issues, such as ignoring certain files, directories, words, or regular expressions. diff --git a/geos-trame/README.rst b/geos-trame/README.rst index 797f1948f..a1c54deb2 100644 --- a/geos-trame/README.rst +++ b/geos-trame/README.rst @@ -47,6 +47,15 @@ To be able to run the test suite, make sure to install the additionals dependenc Then you can run the test with `pytest .` +Optional +-------- + +To use pre-commit hooks (ruff, mypy, yapf,...), make sure to install the dev dependencies: + +.. code-block:: console + + pip install -e '.[dev]' + Regarding GEOS -------------- @@ -54,7 +63,7 @@ This application takes an XML file from the GEOS project to load dynamically all To be able to do that, we need first to generate the corresponding python class based on a xsd schema provided by GEOS. -`For more details `_ +`For more details `_ Features -------- diff --git a/geos-trame/pyproject.toml b/geos-trame/pyproject.toml index 1031aa23a..44b5cec6b 100644 --- a/geos-trame/pyproject.toml +++ b/geos-trame/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "funcy==2.0", "typing_inspect==0.9.0", "typing_extensions>=4.12", + "PyYAML", ] [project.optional-dependencies] @@ -56,14 +57,16 @@ build = [ dev = [ "pylint", "mypy", - "black", - "isort" + "types-PyYAML", + "ruff", + "pre-commit" ] test = [ "pytest==8.3.3", "pytest-seleniumbase==4.31.6", "pixelmatch==0.3.0", "Pillow==11.0.0", + "pytest-mypy==0.10.3", "pytest-xprocess==1.0.2" ] @@ -72,10 +75,10 @@ file = "README.md" content-type = "text/markdown" [project.scripts] -geos-trame = "geos_trame.app.__main__:main" +geos-trame = "geos.trame.app.main:main" [project.entry-points.jupyter_serverproxy_servers] -geos-trame = "geos_trame.app.jupyter:jupyter_proxy_info" +geos-trame = "geos.trame.app.jupyter:jupyter_proxy_info" [tool.setuptools] license-files = ["LICENSE"] @@ -116,18 +119,3 @@ disable = [ "R0913", # (too-many-arguments) "W0105", # (pointless-string-statement) ] - -[tool.black] -line-length = 88 -target-version = ['py310'] -include = '\.pyi?$' -extend-exclude = ''' -src/geos_trame/schema_generated/*.py -''' - -[tool.isort] -profile = "black" -src_paths = ["src", "tests"] -blackArgs = ["--preview"] -py_version = 310 - diff --git a/geos-trame/src/geos_trame/__init__.py b/geos-trame/src/geos/trame/__init__.py similarity index 100% rename from geos-trame/src/geos_trame/__init__.py rename to geos-trame/src/geos/trame/__init__.py diff --git a/geos-trame/src/geos_trame/app/__init__.py b/geos-trame/src/geos/trame/app/__init__.py similarity index 100% rename from geos-trame/src/geos_trame/app/__init__.py rename to geos-trame/src/geos/trame/app/__init__.py diff --git a/geos-trame/src/geos_trame/app/io/__init__.py b/geos-trame/src/geos/trame/app/components/__init__.py similarity index 100% rename from geos-trame/src/geos_trame/app/io/__init__.py rename to geos-trame/src/geos/trame/app/components/__init__.py diff --git a/geos-trame/src/geos_trame/app/ui/alertHandler.py b/geos-trame/src/geos/trame/app/components/alertHandler.py similarity index 59% rename from geos-trame/src/geos_trame/app/ui/alertHandler.py rename to geos-trame/src/geos/trame/app/components/alertHandler.py index b3d416468..4ae36030f 100644 --- a/geos-trame/src/geos_trame/app/ui/alertHandler.py +++ b/geos-trame/src/geos/trame/app/components/alertHandler.py @@ -7,14 +7,14 @@ class AlertHandler( vuetify3.VContainer ): - """ - Vuetify component used to display an alert status. + """Vuetify component used to display an alert status. This alert will be displayed in the bottom right corner of the screen. It will be displayed until closed by the user or after 10 seconds if it is a success or warning. """ - def __init__( self ): + def __init__( self ) -> None: + """Constructor.""" super().__init__( fluid=True, classes="pa-0 ma-0", @@ -31,31 +31,32 @@ def __init__( self ): self.generate_alert_ui() - def generate_alert_ui( self ): - """ - Generate the alert UI. + def generate_alert_ui( self ) -> None: + """Generate the alert UI. The alert will be displayed in the bottom right corner of the screen. Use an abritary z-index value to put the alert on top of the other components. """ - with self: - with vuetify3.VCol( style="width: 40%; position: fixed; right: 50px; bottom: 50px; z-index: 100;", ): - vuetify3.VAlert( - style="max-height: 20vh; overflow-y: auto", - classes="ma-2", - v_for=( "(status, index) in alerts", ), - key="status", - type=( "status.type", "info" ), - text=( "status.message", "" ), - title=( "status.title", "" ), - closable=True, - click_close=( self.on_close, f"[status.id]" ), - ) - - def add_alert( self, type: str, title: str, message: str ): - """ - Add a status to the stack with a unique id. + with ( + self, + vuetify3.VCol( style="width: 40%; position: fixed; right: 50px; bottom: 50px; z-index: 100;", ), + ): + vuetify3.VAlert( + style="max-height: 20vh; overflow-y: auto", + classes="ma-2", + v_for=( "(status, index) in alerts", ), + key="status", + type=( "status.type", "info" ), + text=( "status.message", "" ), + title=( "status.title", "" ), + closable=True, + click_close=( self.on_close, "[status.id]" ), + ) + + def add_alert( self, type: str, title: str, message: str ) -> None: + """Add a status to the stack with a unique id. + If there are more than 5 alerts displayed, remove the oldest. A warning will be automatically closed after 10 seconds. """ @@ -77,21 +78,15 @@ def add_alert( self, type: str, title: str, message: str ): if type == "warning": asyncio.get_event_loop().call_later( self.__lifetime_of_alert, self.on_close, alert_id ) - async def add_warning( self, title: str, message: str ): - """ - Add an alert of type "warning" - """ + async def add_warning( self, title: str, message: str ) -> None: + """Add an alert of type 'warning'.""" self.add_alert( "warning", title, message ) - async def add_error( self, title: str, message: str ): - """ - Add an alert of type "error" - """ + async def add_error( self, title: str, message: str ) -> None: + """Add an alert of type 'error'.""" self.add_alert( "error", title, message ) - def on_close( self, alert_id ): - """ - Remove in the state the alert associated to the given id. - """ + def on_close( self, alert_id: int ) -> None: + """Remove in the state the alert associated to the given id.""" self.state.alerts = list( filter( lambda i: i[ "id" ] != alert_id, self.state.alerts ) ) self.state.flush() diff --git a/geos-trame/src/geos/trame/app/components/properties_checker.py b/geos-trame/src/geos/trame/app/components/properties_checker.py new file mode 100644 index 000000000..10157eaf2 --- /dev/null +++ b/geos-trame/src/geos/trame/app/components/properties_checker.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +from typing import Any + +from trame_client.widgets.core import AbstractElement +from trame_simput import get_simput_manager + +from geos.trame.app.data_types.field_status import FieldStatus +from geos.trame.app.data_types.renderable import Renderable +from geos.trame.app.deck.tree import DeckTree +from geos.trame.app.ui.viewer.regionViewer import RegionViewer +from geos.trame.app.utils.geos_utils import group_name_ref_array_to_list + +# Doc reference: https://geosx-geosx.readthedocs-hosted.com/en/latest/docs/sphinx/datastructure/CompleteXMLSchema.html +attributes_to_check = [ ( "region_attribute", str ), ( "fields_to_import", list ), ( "surfacicFieldsToImport", list ) ] + + +class PropertiesChecker( AbstractElement ): + """Validity checker of properties within a deck tree.""" + + def __init__( self, tree: DeckTree, region_viewer: RegionViewer, **kwargs: Any ) -> None: + """Constructor.""" + super().__init__( "div", **kwargs ) + + self.tree = tree + self.region_viewer = region_viewer + self.simput_manager = get_simput_manager( id=self.state.sm_id ) + + def check_fields( self ) -> None: + """Check all the fields in the deck_tree. + + Get the names of all the cell data arrays from the input of the region viewer, then check that + all the attributes in `attributes_to_check` have a value corresponding to one of the array names. + """ + array_names = self._get_array_names() + for field in self.state.deck_tree: + self._check_field( field, array_names ) + self.state.dirty( "deck_tree" ) + self.state.flush() + + def _check_field( self, field: dict, array_names: list[ str ] ) -> None: + """Check that all the attributes in `attributes_to_check` have a value corresponding to one of the array names. + + Set the `valid` property to the result of this check, and if necessary, indicate which properties are invalid. + """ + if len( array_names ) == 0 and Renderable.VTKMESH.value in field[ "id" ]: + self.ctrl.load_vtkmesh_from_id( field[ "id" ] ) + array_names = self._get_array_names() + field[ "drawn" ] = True + field[ "valid" ] = FieldStatus.VALID.value + field[ "invalid_properties" ] = [] + + proxy = self.simput_manager.proxymanager.get( field[ "id" ] ) + if proxy is not None: + for attr, expected_type in attributes_to_check: + if attr in proxy.definition: + if ( expected_type is str and proxy[ attr ] # value is not empty (valid) + and proxy[ attr ] not in array_names # value is not in the expected names + ): + field[ "invalid_properties" ].append( attr ) + elif expected_type is list: + arrays: list[ str ] | None = group_name_ref_array_to_list( proxy[ attr ] ) + if arrays is None: + field[ "invalid_properties" ].append( attr ) + continue + for array_name in arrays: + if array_name not in array_names: + field[ "invalid_properties" ].append( attr ) + break + + if len( field[ "invalid_properties" ] ) != 0: + field[ "valid" ] = FieldStatus.INVALID.value + else: + field.pop( "invalid_properties", None ) + + if field[ "children" ] is not None: + # Parents are only valid if all children are valid + field[ "invalid_children" ] = [] + for child in field[ "children" ]: + self._check_field( child, array_names ) + if child[ "valid" ] == FieldStatus.INVALID.value: + field[ "valid" ] = FieldStatus.INVALID.value + field[ "invalid_children" ].append( child[ "title" ] ) + if len( field[ "invalid_children" ] ) == 0: + field.pop( "invalid_children", None ) + + def _get_array_names( self ) -> list[ str ]: + cellData = self.region_viewer.input.GetCellData() + return [ cellData.GetArrayName( i ) for i in range( cellData.GetNumberOfArrays() ) ] diff --git a/geos-trame/src/geos_trame/app/core.py b/geos-trame/src/geos/trame/app/core.py similarity index 65% rename from geos-trame/src/geos_trame/app/core.py rename to geos-trame/src/geos/trame/app/core.py index a059c62cc..1f66b96ab 100644 --- a/geos-trame/src/geos_trame/app/core.py +++ b/geos-trame/src/geos/trame/app/core.py @@ -1,20 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner + from trame.ui.vuetify3 import VAppLayout from trame.decorators import TrameApp from trame.widgets import html, simput from trame.widgets import vuetify3 as vuetify +from trame_server import Server +from trame_server.controller import Controller +from trame_server.state import State from trame_simput import get_simput_manager -from geos_trame import module -from geos_trame.app.deck.tree import DeckTree -from geos_trame.app.ui.editor import DeckEditor -from geos_trame.app.ui.inspector import DeckInspector -from geos_trame.app.ui.plotting import DeckPlotting -from geos_trame.app.ui.timeline import TimelineEditor -from geos_trame.app.ui.viewer.viewer import DeckViewer -from geos_trame.app.ui.alertHandler import AlertHandler +from geos.trame import module +from geos.trame.app.deck.tree import DeckTree +from geos.trame.app.io.data_loader import DataLoader +from geos.trame.app.ui.viewer.regionViewer import RegionViewer +from geos.trame.app.ui.viewer.wellViewer import WellViewer +from geos.trame.app.components.properties_checker import PropertiesChecker +from geos.trame.app.ui.editor import DeckEditor +from geos.trame.app.ui.inspector import DeckInspector +from geos.trame.app.ui.plotting import DeckPlotting +from geos.trame.app.ui.timeline import TimelineEditor +from geos.trame.app.ui.viewer.viewer import DeckViewer +from geos.trame.app.components.alertHandler import AlertHandler import sys @@ -22,8 +30,14 @@ @TrameApp() class GeosTrame: - def __init__( self, server, file_name: str ): - + def __init__( self, server: Server, file_name: str ) -> None: + """Constructor.""" + self.alertHandler: AlertHandler | None = None + self.deckPlotting: DeckPlotting | None = None + self.deckViewer: DeckViewer | None = None + self.deckEditor: DeckEditor | None = None + self.timelineEditor: TimelineEditor | None = None + self.deckInspector: DeckInspector | None = None self.server = server server.enable_module( module ) @@ -49,32 +63,49 @@ def __init__( self, server, file_name: str ): # Tree self.tree = DeckTree( self.state.sm_id ) + # Viewers + self.region_viewer = RegionViewer() + self.well_viewer = WellViewer( 5, 5 ) + + # Data loader + self.data_loader = DataLoader( self.tree, self.region_viewer, self.well_viewer, trame_server=server ) + + # Properties checker + self.properties_checker = PropertiesChecker( self.tree, self.region_viewer, trame_server=server ) + # TODO put as a modal window self.set_input_file( file_name=self.state.input_file ) # Load components - self.ui = self.build_ui() + self.build_ui() @property - def state( self ): + def state( self ) -> State: + """Getter for the state.""" return self.server.state @property - def ctrl( self ): + def ctrl( self ) -> Controller: + """Getter for the controller.""" return self.server.controller - def set_input_file( self, file_name, file_str=None ): - """sets the input file of the InputTree object and populates simput/ui""" + def set_input_file( self, file_name: str ) -> None: + """Sets the input file of the InputTree object and populates simput/ui.""" self.tree.set_input_file( file_name ) - def deck_ui( self ): - """Generates the UI for the deck edition / visualization tab""" + def deck_ui( self ) -> None: + """Generates the UI for the deck edition / visualization tab.""" with vuetify.VRow( classes="mb-6 fill-height" ): with vuetify.VCol( cols=2, order=1, ): - self.deckInspector = DeckInspector( source=self.tree, classes="fill-height" ) + self.deckInspector = DeckInspector( source=self.tree, classes="fit-content" ) + vuetify.VBtn( + text="Check fields", + classes="ma-4", + click=( self.properties_checker.check_fields, ), + ) with vuetify.VCol( cols=10, @@ -99,6 +130,8 @@ def deck_ui( self ): ): self.deckViewer = DeckViewer( source=self.tree, + region_viewer=self.region_viewer, + well_viewer=self.well_viewer, classes="ma-2", style="flex: 1; height: 60%; width: 100%;", ) @@ -109,23 +142,18 @@ def deck_ui( self ): style="flex: 1; height: 40%; width: 100%;", ) - def build_ui( self, *args, **kwargs ): - """Generates the full UI for the GEOS Trame Application""" - + def build_ui( self ) -> None: + """Generates the full UI for the GEOS Trame Application.""" with VAppLayout( self.server ) as layout: self.simput_widget.register_layout( layout ) self.alertHandler = AlertHandler() - def on_tab_change( tab_idx ): - pass - with html.Div( style="position: relative; display: flex; border-bottom: 1px solid gray", ): with vuetify.VTabs( v_model=( "tab_idx", 0 ), style="z-index: 1;", color="grey", - change=( on_tab_change, "[$event]" ), ): for tab_label in [ "Input File", "Execute", "Results Viewer" ]: vuetify.VTab( tab_label ) @@ -134,18 +162,20 @@ def on_tab_change( tab_idx ): style= "position: absolute; top: 0; left: 0; height: 100%; width: 100%; display: flex; align-items: center; justify-content: center;", ): - with html.Div( - v_if=( "tab_idx == 0", ), - style= - "height: 100%; width: 100%; display: flex; align-items: center; justify-content: flex-end;", - ): - with vuetify.VBtn( + with ( + html.Div( + v_if=( "tab_idx == 0", ), + style= + "height: 100%; width: 100%; display: flex; align-items: center; justify-content: flex-end;", + ), + vuetify.VBtn( click=self.tree.write_files, icon=True, style="z-index: 1;", id="save-button", - ): - vuetify.VIcon( "mdi-content-save-outline" ) + ), + ): + vuetify.VIcon( "mdi-content-save-outline" ) with html.Div( style= @@ -154,21 +184,14 @@ def on_tab_change( tab_idx ): ): vuetify.VBtn( "Run", - # click=self.executor.run, - # disabled=( - # "exe_running || exe_use_threading && exe_threads < 2 || exe_use_mpi && exe_processes < 2", - # ), style="z-index: 1;", ) vuetify.VBtn( "Kill", - # click=self.executor.kill, - # disabled=("!exe_running",), style="z-index: 1;", ) vuetify.VBtn( "Clear", - # click=self.ctrl.terminal_clear, style="z-index: 1;", ) diff --git a/geos-trame/src/geos_trame/app/ui/viewer/__init__.py b/geos-trame/src/geos/trame/app/data_types/__init__.py similarity index 100% rename from geos-trame/src/geos_trame/app/ui/viewer/__init__.py rename to geos-trame/src/geos/trame/app/data_types/__init__.py diff --git a/geos-trame/src/geos/trame/app/data_types/field_status.py b/geos-trame/src/geos/trame/app/data_types/field_status.py new file mode 100644 index 000000000..b9ac00a30 --- /dev/null +++ b/geos-trame/src/geos/trame/app/data_types/field_status.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class FieldStatus( Enum ): + UNCHECKED = 0 + VALID = 1 + INVALID = 2 diff --git a/geos-trame/src/geos/trame/app/data_types/renderable.py b/geos-trame/src/geos/trame/app/data_types/renderable.py new file mode 100644 index 000000000..e0312401d --- /dev/null +++ b/geos-trame/src/geos/trame/app/data_types/renderable.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +from enum import Enum + + +class Renderable( Enum ): + """Enum class for renderable types and their ids.""" + VTKMESH = "VTKMesh" + INTERNALMESH = "InternalMesh" + INTERNALWELL = "InternalWell" + PERFORATION = "Perforation" + VTKWELL = "VTKWell" diff --git a/geos-trame/src/geos/trame/app/data_types/tree_node.py b/geos-trame/src/geos/trame/app/data_types/tree_node.py new file mode 100644 index 000000000..f8c7e50fe --- /dev/null +++ b/geos-trame/src/geos/trame/app/data_types/tree_node.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +from dataclasses import dataclass + + +@dataclass +class TreeNode: + """Single element of the tree, used by `DeckTree`. + + `valid` has to be an int for serialization purposes, but is actually a FieldStatus so only possibles values are: + - 0 (UNCHECKED): Validity check has not been performed. + - 1 (VALID): TreeNode is checked and valid. + - 2 (INVALID): TreeNode is checked and invalid. + """ + + id: str + title: str + children: list + hidden_children: list + is_drawable: bool + drawn: bool + valid: int + + @property + def json( self ) -> dict: + """Get the tree node as json.""" + return { + "id": self.id, + "title": self.title, + "is_drawable": self.is_drawable, + "drawn": self.drawn, + "valid": self.valid, + "children": [ c.json for c in self.children ] if self.children else None, + "hidden_children": ( [ c.json for c in self.hidden_children ] if self.hidden_children else [] ), + } diff --git a/geos-trame/src/geos_trame/widgets/__init__.py b/geos-trame/src/geos/trame/app/deck/__init__.py similarity index 100% rename from geos-trame/src/geos_trame/widgets/__init__.py rename to geos-trame/src/geos/trame/app/deck/__init__.py diff --git a/geos-trame/src/geos/trame/app/deck/file.py b/geos-trame/src/geos/trame/app/deck/file.py new file mode 100644 index 000000000..1c3bf6e5b --- /dev/null +++ b/geos-trame/src/geos/trame/app/deck/file.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Lionel Untereiner +import os +from typing import Any + +from lxml import etree as ElementTree # type: ignore[import-untyped] +from xsdata.formats.dataclass.parsers.config import ParserConfig +from xsdata.formats.dataclass.serializers.config import SerializerConfig +from xsdata.utils import text +from xsdata_pydantic.bindings import DictEncoder, XmlContext, XmlParser, XmlSerializer + +from geos.trame.app.data_types.renderable import Renderable +from geos.trame.app.geosTrameException import GeosTrameException +from geos.trame.app.io.xml_parser import XMLParser +from geos.trame.app.utils.file_utils import normalize_path +from geos.trame.schema_generated.schema_mod import Problem + + +class DeckFile( object ): + """Holds the information of a deck file. Can be empty.""" + + def __init__( self, filename: str, **kwargs: Any ) -> None: + """Constructor. + + Input: + filename: file name of the deck file + """ + super( DeckFile, self ).__init__( **kwargs ) + + self.inspect_tree: dict[ Any, Any ] | None = None + self.pb_dict: dict[ str, Any ] | None = None + self.problem: Problem | None = None + self.xml_parser: XMLParser | None = None + self.root_node = None + self.filename = normalize_path( filename ) + if self.filename: + self.open_deck_file( self.filename ) + self.original_text = "" + self.changed = False + + self.path = os.path.dirname( self.filename ) + + def open_deck_file( self, filename: str ) -> None: + """Opens a file and parses it. + + Input: + filename: file name of the input file + Signals: + input_file_changed: On success + Raises: + GeosTrameException: On invalid input file + """ + self.changed = False + self.root_node = None + + # Do some basic checks on the filename to make sure + # it is probably a real input file since the GetPot + # parser doesn't do any checks. + if not os.path.exists( filename ): + msg = "Input file %s does not exist" % filename + raise GeosTrameException( msg ) + + if not os.path.isfile( filename ): + msg = "Input file %s is not a file" % filename + raise GeosTrameException( msg ) + + if not filename.endswith( ".xml" ): + msg = "Input file %s does not have the proper extension" % filename + raise GeosTrameException( msg ) + + self.xml_parser = XMLParser( filename=filename ) + self.xml_parser.build() + simulation_deck = self.xml_parser.get_simulation_deck() + + context = XmlContext( + element_name_generator=text.pascal_case, + attribute_name_generator=text.camel_case, + ) + parser = XmlParser( context=context, config=ParserConfig() ) + try: + self.problem = parser.parse( simulation_deck, Problem ) + except ElementTree.XMLSyntaxError as e: + msg = "Failed to parse input file %s:\n%s\n" % ( filename, e ) + raise GeosTrameException( msg ) from e + + encoder = DictEncoder( context=context, config=SerializerConfig( indent=" " ) ) + self.pb_dict = { "Problem": encoder.encode( self.problem ) } + self.inspect_tree = build_inspect_tree( encoder.encode( self.problem ) ) + + def to_str( self ) -> str: + """Get the problem as a string.""" + config = SerializerConfig( indent=" ", xml_declaration=False ) + context = XmlContext( + element_name_generator=text.pascal_case, + attribute_name_generator=text.camel_case, + ) + serializer = XmlSerializer( context=context, config=config ) + return serializer.render( self.problem ) + + +def build_inspect_tree( obj: dict ) -> dict: + """Return the fields of a dataclass instance as a new dictionary mapping field names to field values. + + Example usage:: + + @dataclass + class C: + x: int + y: int + + c = C(1, 2) + assert asdict(c) == {'x': 1, 'y': 2} + + If given, 'dict_factory' will be used instead of built-in dict. + The function applies recursively to field values that are + dataclass instances. This will also look into built-in containers: + tuples, lists, and dicts. Other objects are copied with 'copy.deepcopy()'. + """ + return _build_inspect_tree_inner( "Problem", obj, [] ) + + +def _build_inspect_tree_inner( key: str, obj: dict, path: list ) -> dict: + sub_node = { + "title": obj.get( "name", key ), + "children": [], + "is_drawable": key in ( item.value for item in Renderable ), + "drawn": False, + } + + for key, value in obj.items(): + + if isinstance( value, list ): + for idx, item in enumerate( value ): + if isinstance( item, dict ): + more_results = _build_inspect_tree_inner( key, item, path + [ key ] + [ idx ] ) + # for another_result in more_results: + sub_node[ "children" ].append( more_results ) + + sub_node[ "id" ] = "Problem/" + "/".join( map( str, path ) ) + + return sub_node diff --git a/geos-trame/src/geos_trame/app/deck/tree.py b/geos-trame/src/geos/trame/app/deck/tree.py similarity index 54% rename from geos-trame/src/geos_trame/app/deck/tree.py rename to geos-trame/src/geos/trame/app/deck/tree.py index 708bf34b5..4979f3bec 100644 --- a/geos-trame/src/geos_trame/app/deck/tree.py +++ b/geos-trame/src/geos/trame/app/deck/tree.py @@ -1,53 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner -import dpath -import funcy - import os +from collections import defaultdict +from typing import Any +import dpath +import funcy +from pydantic import BaseModel +from trame_simput import get_simput_manager from xsdata.formats.dataclass.parsers.config import ParserConfig from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.utils import text from xsdata_pydantic.bindings import DictDecoder, XmlContext, XmlSerializer -from geos_trame.app.geosTrameException import GeosTrameException -from geos_trame.schema_generated.schema_mod import BaseModel, Problem, Included, File - -from .file import DeckFile, format_xml, normalize_path - -from collections import defaultdict -from trame_simput import get_simput_manager - - -def recursive_dict( element ): - return element.tag, dict( map( recursive_dict, element ) ) or element.text +from geos.trame.app.deck.file import DeckFile +from geos.trame.app.geosTrameException import GeosTrameException +from geos.trame.app.utils.file_utils import normalize_path, format_xml +from geos.trame.schema_generated.schema_mod import Problem, Included, File, Functions class DeckTree( object ): - """ - A tree that represents an deck file along with all the available blocks and parameters. - """ + """A tree that represents a deck file along with all the available blocks and parameters.""" - def __init__( self, sm_id=None, **kwds ): - """ - Constructor. - """ - super( DeckTree, self ).__init__( **kwds ) + def __init__( self, sm_id: str | None = None, **kwargs: Any ) -> None: + """Constructor.""" + super( DeckTree, self ).__init__( **kwargs ) - self.input_file = None - self.input_filename = None - self.input_folder = None - self.input_real_filename = None - # self._copyDefaultTree() + self.input_file: DeckFile | None = None + self.input_filename: str | None = None + self.input_folder: str | None = None self.root = None - self.path_map = {} self.input_has_errors = False self._sm_id = sm_id - def set_input_file( self, input_filename ): - """ - Set a new input file + def set_input_file( self, input_filename: str ) -> None: + """Set a new input file. + Input: input_filename[str]: The name of the input file Return: @@ -57,66 +46,70 @@ def set_input_file( self, input_filename ): self.input_filename = input_filename self.input_file = DeckFile( self.input_filename ) self.input_folder = os.path.dirname( self.input_file.filename ) - self.input_real_filename = os.path.basename( self.input_file.filename ) - except Exception as e: - msg = "set_input_file exception: %s" % e - return GeosTrameException( msg ) - - def root_fields( self ) -> list[ str ]: - return self.input_file.root_fields + except GeosTrameException: + return def get_mesh( self ) -> str: + """Get the path of the mesh.""" + assert self.input_file is not None and self.input_file.problem is not None return normalize_path( self.input_file.path + "/" + self.input_file.problem.mesh[ 0 ].vtkmesh[ 0 ].file ) - def get_abs_path( self, file ) -> str: + def get_abs_path( self, file: str ) -> str: + """Get the absolute path from a path.""" + assert self.input_file is not None and self.input_file.path is not None return normalize_path( self.input_file.path + "/" + file ) def to_str( self ) -> str: + """Get the input file as a string.""" + assert self.input_file is not None return self.input_file.to_str() def get_tree( self ) -> dict: + """Get the tree from the input file.""" + assert self.input_file is not None and self.input_file.inspect_tree is not None return self.input_file.inspect_tree - def update( self, path, key, value ) -> None: + def update( self, path: str, key: str, value: Any ) -> None: + """Update the tree.""" new_path = [ int( x ) if x.isdigit() else x for x in path.split( "/" ) ] new_path.append( key ) + assert self.input_file is not None and self.input_file.pb_dict is not None funcy.set_in( self.input_file.pb_dict, new_path, value ) - def search( self, path ) -> dict: - new_path = [ int( x ) if x.isdigit() else x for x in path.split( "/" ) ] + def _search( self, path: str ) -> list | None: + new_path = path.split( "/" ) if self.input_file is None: - return + return None + assert self.input_file.pb_dict is not None return dpath.values( self.input_file.pb_dict, new_path ) - def decode( self, path ): - data = self.search( path ) + def decode( self, path: str ) -> BaseModel | None: + """Decode the given file to a BaseModel.""" + data = self._search( path ) if data is None: - return + return None context = XmlContext( element_name_generator=text.pascal_case, attribute_name_generator=text.camel_case, ) decoder = DictDecoder( context=context, config=ParserConfig() ) - node = decoder.decode( data[ 0 ] ) - return node - - def decode_data( self, data: BaseModel ) -> str: - """ - Convert a data to a xml serializable file - """ - if data is None: - return + return decoder.decode( data[ 0 ] ) + @staticmethod + def decode_data( data: dict ) -> Problem: + """Convert a data to a xml serializable file.""" context = XmlContext( element_name_generator=text.pascal_case, attribute_name_generator=text.camel_case, ) decoder = DictDecoder( context=context, config=ParserConfig() ) - node = decoder.decode( data ) + node: Problem = decoder.decode( data ) return node - def to_xml( self, obj ) -> str: + @staticmethod + def to_xml( obj: BaseModel ) -> str: + """Convert the given obj to xml.""" context = XmlContext( element_name_generator=text.pascal_case, attribute_name_generator=text.camel_case, @@ -127,94 +120,83 @@ def to_xml( self, obj ) -> str: return format_xml( serializer.render( obj ) ) - def timeline( self ) -> dict: + def timeline( self ) -> list[ dict ] | None: + """Get the timeline.""" if self.input_file is None: - return + return None if self.input_file.problem is None: - return - if self.input_file.problem.events is None: - return + return None - timeline = list() + timeline = [] # list root events global_id = 0 for e in self.input_file.problem.events[ 0 ].periodic_event: - item = dict() - item[ "id" ] = global_id - item[ "summary" ] = e.name - item[ "start_date" ] = e.begin_time + item: dict[ str, str | int ] = { + "id": global_id, + "summary": e.name, + "start_date": e.begin_time, + } timeline.append( item ) global_id = global_id + 1 return timeline - def plots( self ): + def plots( self ) -> list[ Functions ]: + """Get the functions in the current problem.""" + assert self.input_file is not None and self.input_file.problem is not None return self.input_file.problem.functions - def write_files( self ): - """ - Write geos files with all changes made by the user. - """ - - pb = self.search( "Problem" ) + def write_files( self ) -> None: + """Write geos files with all changes made by the user.""" + pb = self._search( "Problem" ) + if pb is None: + return files = self._split( pb ) for filepath, content in files.items(): - model_loaded: BaseModel = self.decode_data( content ) - model_with_changes: BaseModel = self._apply_changed_properties( model_loaded ) + model_loaded: Problem = DeckTree.decode_data( content ) + model_with_changes: Problem = self._apply_changed_properties( model_loaded ) + assert ( self.input_file is not None and self.input_file.xml_parser is not None ) if self.input_file.xml_parser.contains_include_files(): includeName: str = self.input_file.xml_parser.get_relative_path_of_file( filepath ) - self._append_include_file( model_with_changes, includeName ) + DeckTree._append_include_file( model_with_changes, includeName ) - model_as_xml: str = self.to_xml( model_with_changes ) + model_as_xml: str = DeckTree.to_xml( model_with_changes ) basename = os.path.basename( filepath ) + assert self.input_folder is not None edited_folder_path = self.input_folder - location = edited_folder_path + "/" + self._append_id( basename ) + location = edited_folder_path + "/" + DeckTree._append_id( basename ) with open( location, "w" ) as file: file.write( model_as_xml ) file.close() - def _setInputFile( self, input_file ): - """ - Copies the nodes of an input file into the tree - Input: - input_file[InputFile]: Input file to copy - Return: - bool: True if successful - """ - self.input_has_errors = False - if input_file.root_node is None: - return False - self.input_file = input_file - self.input_filename = input_file.filename - - return False + @staticmethod + def _append_include_file( model: Problem, included_file_path: str ) -> None: + """Append an Included object which follows this structure according to the documentation. - def _append_include_file( self, model: Problem, includedFilePath: str ) -> None: - """ - Append an Included object which follows this structure according to the documentation: - Only Problem can contains an included tag: + Only Problem can contain an included tag: https://geosx-geosx.readthedocs-hosted.com/en/latest/docs/sphinx/datastructure/CompleteXMLSchema.html """ - if len( includedFilePath ) == 0: - return None + if len( included_file_path ) == 0: + return includedTag = Included() - includedTag.file.append( File( name=self._append_id( includedFilePath ) ) ) + includedTag.file.append( File( name=DeckTree._append_id( included_file_path ) ) ) model.included.append( includedTag ) - def _append_id( self, filename: str ) -> str: - """ - Return the new filename with the correct suffix and his extension. The suffix - added will be '_vX' where X is the incremented value of the current version. + @staticmethod + def _append_id( filename: str ) -> str: + """Return the new filename with the correct suffix and his extension. + + The suffix added will be '_vX' where X is the incremented value of the current version. '_v0' if any suffix is present. """ name, ext = os.path.splitext( filename ) @@ -231,30 +213,25 @@ def _append_id( self, filename: str ) -> str: suffix += str( version ) return f"{name}{suffix}{ext}" - def _convert_to_camel_case( self, content: str ) -> str: - """ - Convert any given string in CamelCase. + @staticmethod + def _convert_to_camel_case( content: str ) -> str: + """Convert any given string in CamelCase. Useful to transform trame_simput convention in geos schema names convention. """ camel_case_str: str = content.title() return camel_case_str.replace( "_", "" ) - def _convert_to_snake_case( self, content: str ) -> str: - """ - Convert any given string in snake case. + @staticmethod + def _convert_to_snake_case( content: str ) -> str: + """Convert any given string in snake case. Useful to transform geos schema names convention in trame_simput convention. """ return "".join( [ "_" + char.lower() if char.isupper() else char for char in content ] ).lstrip( "_" ) def _apply_changed_properties( self, model: Problem ) -> Problem: - """ - Retrieves all edited 'properties' from the simput_manager and apply it to a - given model. - - """ - + """Retrieves all edited 'properties' from the simput_manager and apply it to a given model.""" manager = get_simput_manager( self._sm_id ) modified_proxy_ids: set[ str ] = manager.proxymanager.dirty_proxy_data @@ -265,39 +242,35 @@ def _apply_changed_properties( self, model: Problem ) -> Problem: for proxy_id in modified_proxy_ids: properties = manager.data( proxy_id )[ "properties" ] - events = self._get_base_model_from_path( model_as_dict, proxy_id ) - if events is None: - continue + events = DeckTree._get_base_model_from_path( model_as_dict, proxy_id ) events_as_dict = dict( events ) for property_name, value in properties.items(): events_as_dict[ property_name ] = value - self._set_base_model_properties( model_as_dict, proxy_id, events_as_dict ) + DeckTree._set_base_model_properties( model_as_dict, proxy_id, events_as_dict ) - model = getattr( model, "model_validate" )( model_as_dict ) + model = model.model_validate( model_as_dict ) return model - def _convert_proxy_path_into_proxy_names( self, proxy_path: str ) -> list[ str ]: - """ - Split a given proxy path into a list of proxy names. + @staticmethod + def _convert_proxy_path_into_proxy_names( proxy_path: str ) -> list[ str ]: + """Split a given proxy path into a list of proxy names. note: each proxy name will be converted in snake case to fit with the pydantic model naming convention. """ - splitted_path = proxy_path.split( "/" ) - splitted_path_without_root = splitted_path[ 1: ] - - return [ self._convert_to_snake_case( proxy ) for proxy in splitted_path_without_root ] + split_path = proxy_path.split( "/" ) + split_path_without_root = split_path[ 1: ] - def _set_base_model_properties( self, model: dict, proxy_path: str, properties: dict ) -> None: - """ - Apply all changed property to the model for a specific proxy. - """ + return [ DeckTree._convert_to_snake_case( proxy ) for proxy in split_path_without_root ] + @staticmethod + def _set_base_model_properties( model: dict, proxy_path: str, properties: dict ) -> None: + """Apply all changed property to the model for a specific proxy.""" # retrieve the whole BaseModel list to the modified proxy - proxy_names = self._convert_proxy_path_into_proxy_names( proxy_path ) + proxy_names = DeckTree._convert_proxy_path_into_proxy_names( proxy_path ) model_copy = model - models = [] + models: list[ tuple[ str, dict ] ] = [] for proxy_name in proxy_names: is_dict = type( model_copy ) is dict is_list = type( model_copy ) is list @@ -307,22 +280,22 @@ def _set_base_model_properties( self, model: dict, proxy_path: str, properties: model_copy = dict( model_copy ) if proxy_name.isnumeric() and int( proxy_name ) < len( model_copy ): - models.append( [ proxy_name, model_copy ] ) - model_copy = model_copy[ int( proxy_name ) ] + models.append( ( proxy_name, model_copy ) ) + model_copy = model_copy[ proxy_name ] continue if proxy_name in model_copy: - models.append( [ proxy_name, model_copy ] ) + models.append( ( proxy_name, model_copy ) ) model_copy = model_copy[ proxy_name ] else: - return None + return models.reverse() # propagate the modification to the parent node index = -1 for model_inverted in models: - prop_identifier = model_inverted[ 0 ] + prop_identifier: str = model_inverted[ 0 ] if prop_identifier.isnumeric(): index = int( prop_identifier ) @@ -333,22 +306,18 @@ def _set_base_model_properties( self, model: dict, proxy_path: str, properties: current_node = model_inverted[ 1 ] current_base_model = current_node[ prop_identifier ][ index ] - current_base_model = getattr( current_base_model, "model_validate" )( properties ) + current_base_model = current_base_model.model_validate( properties ) current_node[ prop_identifier ][ index ] = current_base_model - properties = dict( current_base_model ) break models.reverse() - model = models[ 0 ] - def _get_base_model_from_path( self, model: dict, proxy_id: str ) -> BaseModel: - """ - Retrieve the BaseModel changed from the proxy id. The proxy_id is a unique path - from the simput manager. - """ - proxy_names = self._convert_proxy_path_into_proxy_names( proxy_id ) + @staticmethod + def _get_base_model_from_path( model: dict, proxy_id: str ) -> dict: + """Retrieve the BaseModel changed from the proxy id. The proxy_id is a unique path from the simput manager.""" + proxy_names = DeckTree._convert_proxy_path_into_proxy_names( proxy_id ) model_found: dict = model @@ -361,7 +330,7 @@ def _get_base_model_from_path( self, model: dict, proxy_id: str ) -> BaseModel: if is_class: model_found = dict( model_found ) - # path can contains a numerical index, useful to be sure that each + # path can contain a numerical index, useful to be sure that each # proxy is unique, typically used for a list of proxy located at the same level if proxy_name.isnumeric() and int( proxy_name ) < len( model_found ): model_found = model_found[ int( proxy_name ) ] @@ -372,11 +341,12 @@ def _get_base_model_from_path( self, model: dict, proxy_id: str ) -> BaseModel: return model_found - def _split( self, xml: str ) -> dict[ str, str ]: + def _split( self, xml: list ) -> defaultdict[ str, dict[ str, str ] ]: + assert self.input_file is not None and self.input_file.xml_parser is not None data = self.input_file.xml_parser.file_to_tags - restructured_files = defaultdict( dict ) + restructured_files: defaultdict[ str, dict ] = defaultdict( dict ) for file_path, associated_tags in data.items(): - restructured_files[ file_path ] = dict() + restructured_files[ file_path ] = {} for tag, contents in xml[ 0 ].items(): if len( contents ) == 0: continue @@ -385,14 +355,3 @@ def _split( self, xml: str ) -> dict[ str, str ]: restructured_files[ file_path ][ tag ] = contents return restructured_files - - -if __name__ == "__main__": - import sys - - if len( sys.argv ) < 3: - print( "Usage: " ) - exit( 1 ) - input_file_path = sys.argv[ 2 ] - deck_tree = DeckTree() - deck_tree.setInputFile( input_file_path ) diff --git a/geos-trame/src/geos_trame/app/geosTrameException.py b/geos-trame/src/geos/trame/app/geosTrameException.py similarity index 100% rename from geos-trame/src/geos_trame/app/geosTrameException.py rename to geos-trame/src/geos/trame/app/geosTrameException.py diff --git a/geos-trame/src/geos/trame/app/io/__init__.py b/geos-trame/src/geos/trame/app/io/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/geos-trame/src/geos/trame/app/io/data_loader.py b/geos-trame/src/geos/trame/app/io/data_loader.py new file mode 100644 index 000000000..d5a7bd2f1 --- /dev/null +++ b/geos-trame/src/geos/trame/app/io/data_loader.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +from typing import Type, Any + +import numpy as np +from trame_client.widgets.core import AbstractElement +import pyvista as pv + +from geos.trame.app.deck.tree import DeckTree +from geos.trame.app.geosTrameException import GeosTrameException +from geos.trame.app.ui.viewer.regionViewer import RegionViewer +from geos.trame.app.ui.viewer.wellViewer import WellViewer +from geos.trame.app.utils.pv_utils import read_unstructured_grid +from geos.trame.schema_generated.schema_mod import ( + Vtkmesh, + Vtkwell, + Perforation, + InternalWell, +) + + +class DataLoader( AbstractElement ): + """Helper class to handle IO operations for data loading.""" + + def __init__( + self, + source: DeckTree, + region_viewer: RegionViewer, + well_viewer: WellViewer, + **kwargs: Any, + ) -> None: + """Constructor.""" + super().__init__( "span", **kwargs ) + + self.source = source + self.region_viewer = region_viewer + self.well_viewer = well_viewer + + self.state.change( "object_state" )( self._update_object_state ) + self.ctrl.load_vtkmesh_from_id.add( self.load_vtkmesh_from_id ) + + def load_vtkmesh_from_id( self, node_id: str ) -> None: + """Load the data at the given id if none is already loaded.""" + if self.region_viewer.input.number_of_cells == 0: + active_block = self.source.decode( node_id ) + if isinstance( active_block, Vtkmesh ): + self._read_mesh( active_block ) + + def _update_object_state( self, object_state: tuple[ str, bool ], **_: dict ) -> None: + + path, show_obj = object_state + + if path == "": + return + + active_block = self.source.decode( path ) + + if isinstance( active_block, Vtkmesh ): + self._update_vtkmesh( active_block, show_obj ) + + if isinstance( active_block, Vtkwell ): + if self.region_viewer.input.number_of_cells == 0 and show_obj: + self.ctrl.on_add_warning( + "Can't display " + active_block.name, + "Please display the mesh before creating a well.", + ) + return + + self._update_vtkwell( active_block, path, show_obj ) + + if isinstance( active_block, InternalWell ): + if self.region_viewer.input.number_of_cells == 0 and show_obj: + self.ctrl.on_add_warning( + "Can't display " + active_block.name, + "Please display the mesh before creating a well", + ) + return + + self._update_internalwell( active_block, path, show_obj ) + + if ( isinstance( active_block, Perforation ) and self.well_viewer.get_number_of_wells() == 0 and show_obj ): + self.ctrl.on_add_warning( + "Can't display " + active_block.name, + "Please display a well before creating a perforation", + ) + return + + self.ctrl.update_viewer( active_block, path, show_obj ) + + def _update_vtkmesh( self, mesh: Vtkmesh, show: bool ) -> None: + if not show: + self.region_viewer.reset() + return + + self._read_mesh( mesh ) + + def _read_mesh( self, mesh: Vtkmesh ) -> None: + unstructured_grid = read_unstructured_grid( self.source.get_abs_path( mesh.file ) ) + self.region_viewer.add_mesh( unstructured_grid ) + + def _update_vtkwell( self, well: Vtkwell, path: str, show: bool ) -> None: + if not show: + self.well_viewer.remove( path ) + return + + well_polydata = pv.read( self.source.get_abs_path( well.file ) ) + if not isinstance( well_polydata, pv.PolyData ): + raise GeosTrameException( f"Expected PolyData, got {type(well_polydata).__name__}" ) + self.well_viewer.add_mesh( well_polydata, path ) + + def _update_internalwell( self, well: InternalWell, path: str, show: bool ) -> None: + """Used to control the visibility of the InternalWell. + + This method will create the mesh if it doesn't exist. + """ + if not show: + self.well_viewer.remove( path ) + return + + points = self.__parse_polyline_property( well.polyline_node_coords, dtype=float ) + connectivity = self.__parse_polyline_property( well.polyline_segment_conn, dtype=int ) + connectivity = connectivity.flatten() + + sorted_points = [] + for point_id in connectivity: + sorted_points.append( points[ point_id ] ) + + well_polydata = pv.MultipleLines( sorted_points ) + self.well_viewer.add_mesh( well_polydata, path ) + + @staticmethod + def __parse_polyline_property( polyline_property: str, dtype: Type[ Any ] ) -> np.ndarray: + """Internal method used to parse and convert a property, such as polyline_node_coords, from an InternalWell. + + This string always follow this for : + "{ { 800, 1450, 395.646 }, { 800, 1450, -554.354 } }" + """ + try: + nodes_str = polyline_property.split( "}, {" ) + points = [] + for i in range( 0, len( nodes_str ) ): + + nodes_str[ i ] = nodes_str[ i ].replace( " ", "" ) + nodes_str[ i ] = nodes_str[ i ].replace( "{", "" ) + nodes_str[ i ] = nodes_str[ i ].replace( "}", "" ) + + point = np.array( nodes_str[ i ].split( "," ), dtype=dtype ) + + points.append( point ) + + return np.array( points, dtype=dtype ) + except ValueError as e: + raise GeosTrameException( + "cannot be able to convert the property into a numeric array: ", + ValueError, + ) from e diff --git a/geos-trame/src/geos_trame/app/io/xml_parser.py b/geos-trame/src/geos/trame/app/io/xml_parser.py similarity index 73% rename from geos-trame/src/geos_trame/app/io/xml_parser.py rename to geos-trame/src/geos/trame/app/io/xml_parser.py index a5e9a5958..1bf935128 100644 --- a/geos-trame/src/geos_trame/app/io/xml_parser.py +++ b/geos-trame/src/geos/trame/app/io/xml_parser.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner -import os +import sys import re from os.path import expandvars from pathlib import Path @@ -11,25 +11,20 @@ from collections import defaultdict -from geos_trame.app.geosTrameException import GeosTrameException +from geos.trame.app.geosTrameException import GeosTrameException class XMLParser( object ): - """ - Class used to parse a valid XML geos file and construct a link between - each file when they are included. + """Class used to parse a valid XML geos file and construct a link between each file when they are included. - Useful to be able to able to save it later. + Useful to be able to save it later. """ def __init__( self, filename: str ) -> None: - """ - Constructor which takes in input the xml file used to generate pedantic file. - """ - + """Constructor which takes in input the xml file used to generate pedantic file.""" self.filename = filename - self.file_to_tags = defaultdict( list ) - self.file_to_relative_path = {} + self.file_to_tags: defaultdict = defaultdict( list ) + self.file_to_relative_path: dict = {} expanded_file = Path( expandvars( self.filename ) ).expanduser().resolve() self.file_path = expanded_file.parent @@ -40,44 +35,39 @@ def __init__( self, filename: str ) -> None: tree = ElementTree.parse( expanded_file, parser=parser ) self.root = tree.getroot() except XMLSyntaxError as err: - error_msg = "Invalid XML file. Cannot load " + expanded_file + error_msg = "Invalid XML file. Cannot load " + str( expanded_file ) error_msg += ". Outputted error:\n" + err.msg - print( error_msg, file=os.sys.stderr ) + print( error_msg, file=sys.stderr ) self._is_valid = False def is_valid( self ) -> bool: + """Getter for is_valid.""" if not self._is_valid: - print( "XMLParser isn't valid", file=os.sys.stderr ) + print( "XMLParser isn't valid", file=sys.stderr ) return self._is_valid def build( self ) -> None: + """Read the file.""" if not self.is_valid(): raise GeosTrameException( "Cannot parse this file." ) self._read() def get_simulation_deck( self ) -> ElementTree.Element: + """Get the simulation deck.""" if not self.is_valid(): raise GeosTrameException( "Not valid file, cannot return the deck." ) - return return self.simulation_deck def contains_include_files( self ) -> bool: - """ - Return True if the parsed file contains included file or not. - """ + """Return True if the parsed file contains included file or not.""" return len( self.file_to_relative_path ) > 0 def get_relative_path_of_file( self, filename: str ) -> str: - """ - Return the relative path of a given filename. - """ + """Return the relative path of a given filename.""" return self.file_to_relative_path[ filename ] def _read( self ) -> ElementTree.Element: - """Reads an xml file (and recursively its included files) into memory - - Args: - xmlFilepath (str): The path the file to read. + """Reads a xml file (and recursively its included files) into memory. Returns: SimulationDeck: The simulation deck @@ -91,14 +81,14 @@ def _read( self ) -> ElementTree.Element: if include_node.tag == "Included": for f in include_node.findall( "File" ): self.file_to_relative_path[ self.filename ] = f.get( "name" ) - self._merge_included_xml_files( self.root, self.file_path, f.get( "name" ), includeCount ) + self._merge_included_xml_files( self.root, str( self.file_path ), f.get( "name" ), includeCount ) # Remove 'Included' nodes for include_node in self.root.findall( "Included" ): self.root.remove( include_node ) for neighbor in self.root.iter(): - for key in neighbor.attrib.keys(): + for key in neighbor.attrib: # remove unnecessary whitespaces for indentation s = re.sub( r"\s{2,}", " ", neighbor.get( key ) ) neighbor.set( key, s ) @@ -107,30 +97,30 @@ def _read( self ) -> ElementTree.Element: def _merge_xml_nodes( self, - existingNode: ElementTree.Element, - targetNode: ElementTree.Element, + existing_node: ElementTree.Element, + target_node: ElementTree.Element, fname: str, level: int, ) -> None: """Merge nodes in an included file into the current structure level by level. Args: - existingNode (lxml.etree.Element): The current node in the base xml structure. - targetNode (lxml.etree.Element): The node to insert. + existing_node (lxml.etree.Element): The current node in the base xml structure. + target_node (lxml.etree.Element): The node to insert. + fname (str): The target file name. level (int): The xml file depth. """ if not self.is_valid(): raise GeosTrameException( "Not valid file, cannot merge nodes" ) - return # Copy attributes on the current level - for tk in targetNode.attrib.keys(): - existingNode.set( tk, targetNode.get( tk ) ) + for tk in target_node.attrib: + existing_node.set( tk, target_node.get( tk ) ) # Copy target children into the xml structure currentTag = "" matchingSubNodes = [] - for target in targetNode.getchildren(): + for target in target_node.getchildren(): tags = self.file_to_tags[ fname ] tags.append( target.tag ) insertCurrentLevel = True @@ -139,7 +129,7 @@ def _merge_xml_nodes( # exists at this level if currentTag != target.tag: currentTag = target.tag - matchingSubNodes = existingNode.findall( target.tag ) + matchingSubNodes = existing_node.findall( target.tag ) if matchingSubNodes: targetName = target.get( "name" ) @@ -159,42 +149,42 @@ def _merge_xml_nodes( # Insert any unnamed nodes or named nodes that aren't present # in the current xml structure if insertCurrentLevel: - existingNode.insert( -1, target ) + existing_node.insert( -1, target ) def _merge_included_xml_files( self, root: ElementTree.Element, file_path: str, fname: str, - includeCount: int, - maxInclude: int = 100, + include_count: int, + max_include: int = 100, ) -> None: """Recursively merge included files into the current structure. Args: root (lxml.etree.Element): The root node of the base xml structure. + file_path (str): The file path. fname (str): The name of the target xml file to merge. - includeCount (int): The current recursion depth. - maxInclude (int): The maximum number of xml files to include (default = 100) + include_count (int): The current recursion depth. + max_include (int): The maximum number of xml files to include (default = 100) """ if not self.is_valid(): raise GeosTrameException( "Not valid file, cannot merge nodes" ) - return included_file_path = Path( expandvars( file_path ), fname ) expanded_file = included_file_path.expanduser().resolve() self.file_to_relative_path[ fname ] = "" # Check to see if the code has fallen into a loop - includeCount += 1 - if includeCount > maxInclude: + include_count += 1 + if include_count > max_include: raise Exception( "Reached maximum recursive includes... Is there an include loop?" ) # Check to make sure the file exists if not included_file_path.is_file(): print( - "Included file does not exist: %s" % ( included_file_path ), - file=os.sys.stderr, + "Included file does not exist: %s" % included_file_path, + file=sys.stderr, ) raise Exception( "Check included file path!" ) @@ -204,7 +194,7 @@ def _merge_included_xml_files( includeTree = ElementTree.parse( included_file_path, parser ) includeRoot = includeTree.getroot() except XMLSyntaxError as err: - print( "\nCould not load included file: %s" % ( included_file_path ) ) + print( "\nCould not load included file: %s" % included_file_path ) print( err.msg ) raise Exception( "\nCheck included file!" ) from err @@ -212,7 +202,7 @@ def _merge_included_xml_files( for include_node in includeRoot.findall( "Included" ): for f in include_node.findall( "File" ): self.file_to_relative_path[ fname ] = f.get( "name" ) - self._merge_included_xml_files( root, expanded_file.parent, f.get( "name" ), includeCount ) + self._merge_included_xml_files( root, str( expanded_file.parent ), f.get( "name" ), include_count ) # Merge the results into the xml tree self._merge_xml_nodes( root, includeRoot, fname, 0 ) diff --git a/geos-trame/src/geos_trame/app/__main__.py b/geos-trame/src/geos/trame/app/main.py similarity index 78% rename from geos-trame/src/geos_trame/app/__main__.py rename to geos-trame/src/geos/trame/app/main.py index 238394aee..2ad3b293a 100644 --- a/geos-trame/src/geos_trame/app/__main__.py +++ b/geos-trame/src/geos/trame/app/main.py @@ -2,13 +2,16 @@ # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner from pathlib import Path +from typing import Any -from trame.app import get_server +from trame.app import get_server # type: ignore +from trame_server import Server -from geos_trame.app.core import GeosTrame +from geos.trame.app.core import GeosTrame -def main( server=None, **kwargs ): +def main( server: Server = None, **kwargs: Any ) -> None: + """Main function.""" # Get or create server if server is None: server = get_server() diff --git a/geos-trame/src/geos/trame/app/ui/__init__.py b/geos-trame/src/geos/trame/app/ui/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/geos-trame/src/geos_trame/app/ui/editor.py b/geos-trame/src/geos/trame/app/ui/editor.py similarity index 87% rename from geos-trame/src/geos_trame/app/ui/editor.py rename to geos-trame/src/geos/trame/app/ui/editor.py index a693da434..dc5b4f5b2 100644 --- a/geos-trame/src/geos_trame/app/ui/editor.py +++ b/geos-trame/src/geos/trame/app/ui/editor.py @@ -1,14 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner +from typing import Any + from trame.widgets import code, simput from trame.widgets import vuetify3 as vuetify from trame_simput import get_simput_manager +from geos.trame.app.deck.tree import DeckTree + class DeckEditor( vuetify.VCard ): - def __init__( self, source=None, **kwargs ): + def __init__( self, source: DeckTree, **kwargs: Any ) -> None: + """Constructor.""" super().__init__( **kwargs ) self.tree = source @@ -19,7 +24,7 @@ def __init__( self, source=None, **kwargs ): self.state.active_name = "Problem" self.state.active_snippet = "" - self.state.change( "active_id" )( self.on_active_id ) + self.state.change( "active_id" )( self._on_active_id ) with self: with vuetify.VCardTitle( "Components editor" ): @@ -68,7 +73,7 @@ def __init__( self, source=None, **kwargs ): textmate=( "editor_textmate", None ), ) - def on_active_id( self, active_id, **kwargs ): + def _on_active_id( self, active_id: str | None, **_: Any ) -> None: # this function triggers when a block is selected from the tree in the ui if active_id is None: @@ -86,7 +91,7 @@ def on_active_id( self, active_id, **kwargs ): self.state.active_id = active_id self.state.active_ids = [ active_id ] - if hasattr( active_block, "name" ): + if active_block is not None and hasattr( active_block, "name" ): self.state.active_name = active_block.name else: self.state.active_name = None @@ -97,4 +102,4 @@ def on_active_id( self, active_id, **kwargs ): self.state.active_type = simput_type self.state.active_types = [ simput_type ] - self.state.active_snippet = self.tree.to_xml( active_block ) + self.state.active_snippet = DeckTree.to_xml( active_block ) diff --git a/geos-trame/src/geos/trame/app/ui/inspector.py b/geos-trame/src/geos/trame/app/ui/inspector.py new file mode 100644 index 000000000..2fad1245e --- /dev/null +++ b/geos-trame/src/geos/trame/app/ui/inspector.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Lionel Untereiner +from typing import Any + +import yaml +from pydantic import BaseModel +from trame.widgets import vuetify3 as vuetify, html +from trame_simput import get_simput_manager + +from geos.trame.app.data_types.field_status import FieldStatus +from geos.trame.app.data_types.renderable import Renderable +from geos.trame.app.data_types.tree_node import TreeNode +from geos.trame.app.deck.tree import DeckTree +from geos.trame.app.utils.dict_utils import iterate_nested_dict +from geos.trame.schema_generated.schema_mod import Problem + +vuetify.enable_lab() + + +class DeckInspector( vuetify.VTreeview ): + + def __init__( self, source: DeckTree, listen_to_active: bool = True, **kwargs: Any ) -> None: + """Constructor.""" + super().__init__( + # data + items=( "deck_tree", ), + item_value="id", + **{ + # style + "hoverable": True, + "max_width": 500, + "rounded": True, + # activation logic + "activatable": True, + "activated": ( "active_ids", ), + "active_strategy": "single-independent", + "update_activated": ( self.change_current_id, "$event" ), + # selection logic + "selectable": False, + **kwargs, + }, + ) + self.tree = source + self._source: dict | None = None + self.listen_to_active = listen_to_active + + self.state.object_state = ( "", False ) + + # register used types from Problem + self.simput_types: list = [] + + self.simput_manager = get_simput_manager( id=self.state.sm_id ) + + if source.input_file is None: + return + + self._set_source( source.input_file.problem ) + + def _on_change( topic: str, ids: list | None = None ) -> None: + if ids is not None and topic == "changed": + for obj_id in ids: + proxy = self.simput_manager.proxymanager.get( obj_id ) + self.tree.decode( obj_id ) + for prop in proxy.edited_property_names: + self.tree.update( obj_id, prop, proxy.get_property( prop ) ) + + self.simput_manager.proxymanager.on( _on_change ) + + with self, vuetify.Template( v_slot_append="{ item }" ): + with vuetify.VTooltip( v_if=( "item.valid == 2", ) ): + with vuetify.Template( + v_slot_activator=( "{ props }", ), + __properties__=[ ( "v_slot_activator", "v-slot:activator" ) ], + ): + vuetify.VIcon( v_bind=( "props", ), classes="mr-2", icon="mdi-close", color="red" ) + html.Div( + v_if=( "item.invalid_properties", ), + v_text=( "'Invalid properties: ' + item.invalid_properties", ), + ) + html.Div( + v_if=( "item.invalid_children", ), + v_text=( "'Invalid children: ' + item.invalid_children", ), + ) + + vuetify.VIcon( + v_if=( "item.valid < 2", ), + classes="mr-2", + icon="mdi-check", + color=( "['gray', 'green'][item.valid]", ), + ) + vuetify.VCheckboxBtn( + v_if="item.is_drawable", + focused=True, + dense=True, + hide_details=True, + icon=True, + false_icon="mdi-eye-off", + true_icon="mdi-eye", + update_modelValue=( self._to_draw_change, "[ item.id, $event ] " ), + ) + + def _to_draw_change( self, item_id: str, drawn: bool ) -> None: + self.state.object_state = ( item_id, drawn ) + + @property + def source( self ) -> dict | None: + """Getter for source.""" + return self._source + + # TODO + # v should be a proxy like the one in paraview simple + # maybe it can be Any of schema_mod (e.g. Problem) + def _set_source( self, v: Problem | None ) -> None: + + # TODO replace this snippet + from xsdata.formats.dataclass.serializers.config import SerializerConfig + from xsdata.utils import text + from xsdata_pydantic.bindings import DictEncoder, XmlContext + + context = XmlContext( + element_name_generator=text.pascal_case, + attribute_name_generator=text.camel_case, + ) + + encoder = DictEncoder( context=context, config=SerializerConfig( indent=" " ) ) + self._source = encoder.encode( v ) + assert self._source is not None + # with this one by passing v as Problem + # self._source = v + + if v is None: + self.state.deck_tree = [] + else: + self.state.deck_tree = _object_to_tree( self._source ).get( "children", [] ) + + for path in iterate_nested_dict( self.state.deck_tree ): + + active_block = self.tree.decode( path ) + # active_name = None + + # if hasattr(active_block, "name"): + # active_name = active_block.name + + simput_type = type( active_block ).__name__ + + test = _dump( active_block ) + + if test: + params_dict = {} + for key, _ in test.items(): + params_dict[ key ] = { + "type": "string", + } + + self.simput_types.append( simput_type ) + yaml_str = yaml.dump( { simput_type: params_dict }, sort_keys=False ) + + self.simput_manager.load_model( yaml_content=yaml_str ) + + debug = self.simput_manager.proxymanager.create( simput_type, proxy_id=path ) + + for key, _ in test.items(): + debug.set_property( key, getattr( active_block, key ) ) + debug.commit() + + def change_current_id( self, item_id: str | None = None ) -> None: + """Change the current id of the tree. + + This function is called when the user click on the tree. + """ + if item_id is None: + # Silently ignore, it could occur if the user click on the tree + # and this item is already selected + return + + self.state.active_id = item_id + + +def _get_node_dict( obj: dict, node_id: str, path: list ) -> TreeNode: + children = [] + for key, value in obj.items(): + # todo look isinstance(value, dict): + if isinstance( value, list ): + for idx, item in enumerate( value ): + if isinstance( item, dict ): + children.append( _get_node_dict( item, key, path + [ key ] + [ idx ] ) ) + + node_name = node_id + if "name" in obj: + node_name = obj[ "name" ] + + return TreeNode( + id="Problem/" + "/".join( map( str, path ) ), + title=node_name, + children=children if len( children ) else [], + hidden_children=[], + is_drawable=node_id in ( k.value for k in Renderable ), + drawn=False, + valid=FieldStatus.UNCHECKED.value, + ) + + +def _object_to_tree( obj: dict ) -> dict: + return _get_node_dict( obj, "Problem", [] ).json + + +def _dump( item: Any ) -> dict[ str, Any ] | None: + if isinstance( item, BaseModel ): + subitems: dict[ str, Any ] = {} + + for field, value in item: + + if isinstance( value, str ): + subitems[ field ] = value + continue + return subitems + elif isinstance( item, ( list, tuple, set ) ): # pyright: ignore + # Pyright finds this disgusting; this passes `mypy` though. ` # type: + # ignore` would fail `mypy` is it'd be unused (because there's nothing to + # ignore because `mypy` is content) + # return type(container)( # pyright: ignore + # _dump(i) for i in container # pyright: ignore + # ) + return None + elif isinstance( item, dict ): + # return { + # k: _dump(v) + # for k, v in item.items() # pyright: ignore[reportUnknownVariableType] + # } + return None + else: + return item diff --git a/geos-trame/src/geos_trame/app/ui/plotting.py b/geos-trame/src/geos/trame/app/ui/plotting.py similarity index 77% rename from geos-trame/src/geos_trame/app/ui/plotting.py rename to geos-trame/src/geos/trame/app/ui/plotting.py index ff7bcab8c..8f166a043 100644 --- a/geos-trame/src/geos_trame/app/ui/plotting.py +++ b/geos-trame/src/geos/trame/app/ui/plotting.py @@ -1,15 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner +from typing import Any + import matplotlib.pyplot as plt import numpy as np +from matplotlib.figure import Figure from trame.widgets import matplotlib from trame.widgets import vuetify3 as vuetify +from geos.trame.app.deck.tree import DeckTree + class DeckPlotting( vuetify.VCard ): - def __init__( self, source=None, **kwargs ): + def __init__( self, source: DeckTree, **kwargs: Any ) -> None: + """Constructor.""" super().__init__( **kwargs ) self._source = source @@ -18,8 +24,8 @@ def __init__( self, source=None, **kwargs ): self._filepath = ( source.input_file.path, ) - self.ctrl.permeability = self.permeability - self.ctrl.figure_size = self.figure_size + self.ctrl.permeability = self._permeability + self.ctrl.figure_size = self._figure_size with self: vuetify.VCardTitle( "2D View" ) @@ -29,13 +35,14 @@ def __init__( self, source=None, **kwargs ): self.ctrl.update_figure = html_viewX.update @property - def source( self ): + def source( self ) -> DeckTree: + """Getter for source.""" return self._source - def update_view( self, **kwargs ): + def _update_view( self ) -> None: self.ctrl.view_update( figure=self.ctrl.permeability( **self.ctrl.figure_size() ) ) - def figure_size( self ): + def _figure_size( self ) -> dict: if self.state.figure_size is None: return {} @@ -50,23 +57,27 @@ def figure_size( self ): "dpi": dpi, } - def inverse_gaz( self, x ): + @staticmethod + def _inverse_gaz( x: np.ndarray ) -> np.ndarray: return 1 - x - def permeability( self, **kwargs ): + def _permeability( self, **kwargs: Any ) -> Figure: # read data + assert self.source.input_file is not None for f in self.source.plots(): for t in f.table_function: if t.name == "waterRelativePermeabilityTable": fileX = t.coordinate_files.strip( "{(.+)}" ).strip() + assert fileX is not None and t.voxel_file is not None self.water_x = np.loadtxt( self.source.input_file.path + "/" + fileX ) self.water_y = np.loadtxt( self.source.input_file.path + "/" + t.voxel_file ) if t.name == "gasRelativePermeabilityTable": fileX = t.coordinate_files.strip( "{(.+)}" ).strip() + assert fileX is not None and t.voxel_file is not None gaz_x = np.loadtxt( self.source.input_file.path + "/" + fileX ) - self.gaz_x = self.inverse_gaz( gaz_x ) + self.gaz_x = self._inverse_gaz( gaz_x ) self.gaz_y = np.loadtxt( self.source.input_file.path + "/" + t.voxel_file ) # make drawing diff --git a/geos-trame/src/geos_trame/app/ui/timeline.py b/geos-trame/src/geos/trame/app/ui/timeline.py similarity index 60% rename from geos-trame/src/geos_trame/app/ui/timeline.py rename to geos-trame/src/geos/trame/app/ui/timeline.py index 7dba36672..d6961c0ed 100644 --- a/geos-trame/src/geos_trame/app/ui/timeline.py +++ b/geos-trame/src/geos/trame/app/ui/timeline.py @@ -1,46 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner -from trame.widgets import code, gantt, html, simput +from typing import Any + +from trame.widgets import gantt from trame.widgets import vuetify3 as vuetify from trame_simput import get_simput_manager +from geos.trame.app.deck.tree import DeckTree + class TimelineEditor( vuetify.VCard ): - def __init__( self, source=None, **kwargs ): + def __init__( self, source: DeckTree, **kwargs: Any ) -> None: + """Constructor.""" super().__init__( **kwargs ) self.tree = source self.simput_manager = get_simput_manager( id=self.state.sm_id ) items = self.tree.timeline() - # print(items) - - # DRAFT - # items = [ - # {"id": 1, - # "summary": "outputInjectionPeriod", - # "start_date": "2024-11-02 00:00", - # "end_date": "2024-11-02 00:00", - # "duration": 23 - # }, - # { - # "id": 2, - # "summary": "This is a task with a longer description.", - # "start_date": "2024-11-03 00:00", - # "end_date": "2024-11-04 00:00", - # "duration": 1 - # } - # ] - - items_alt = [ { - "id": 3, - "summary": "Lorem ipsum.", - "start_date": "2024-11-07 00:00", - "end_date": "2024-11-09 00:00", - "duration": 2, - } ] fields = [ { "summary": { @@ -72,12 +51,6 @@ def __init__( self, source=None, **kwargs ): } ] with self: - # with vuetify.VRow( - # rows=2, - # style="width: 100%;", - # dense=True - # classes="fill-height" - # ): vuetify.VCardTitle( "Events View" ) vuetify.VDateInput( label="Select starting simulation date", @@ -86,16 +59,18 @@ def __init__( self, source=None, **kwargs ): placeholder="09/18/2024", ) vuetify.VDivider() - with vuetify.VContainer( "Events timeline" ): - with vuetify.VTimeline( + with ( + vuetify.VContainer( "Events timeline" ), + vuetify.VTimeline( direction="horizontal", truncate_line="both", align="center", side="end", - ): # , truncate_line="both", side="end", line_inset="12"): - with vuetify.VTimelineItem( v_for=( f"item in {items}", ), key="i", value="item", size="small" ): - vuetify.VAlert( "{{ item.summary }}" ) - vuetify.Template( "{{ item.start_date }}", raw_attrs=[ "v-slot:opposite" ] ) + ), + vuetify.VTimelineItem( v_for=( f"item in {items}", ), key="i", value="item", size="small" ), + ): + vuetify.VAlert( "{{ item.summary }}" ) + vuetify.Template( "{{ item.start_date }}", raw_attrs=[ "v-slot:opposite" ] ) with vuetify.VContainer( "Events chart" ): gantt.Gantt( @@ -110,5 +85,6 @@ def __init__( self, source=None, **kwargs ): classes="fill_height", ) - def update_from_js( self, *items ): + def update_from_js( self, *items: tuple ) -> None: + """Update method called from javascript.""" self.state.items = list( items ) diff --git a/geos-trame/src/geos/trame/app/ui/viewer/__init__.py b/geos-trame/src/geos/trame/app/ui/viewer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/geos-trame/src/geos_trame/app/ui/viewer/perforationViewer.py b/geos-trame/src/geos/trame/app/ui/viewer/perforationViewer.py similarity index 70% rename from geos-trame/src/geos_trame/app/ui/viewer/perforationViewer.py rename to geos-trame/src/geos/trame/app/ui/viewer/perforationViewer.py index 28678c5e3..d9db51f9f 100644 --- a/geos-trame/src/geos_trame/app/ui/viewer/perforationViewer.py +++ b/geos-trame/src/geos/trame/app/ui/viewer/perforationViewer.py @@ -5,15 +5,14 @@ class PerforationViewer: - """ - Class representing how storing a GEOS Perforation. - - A perforation is represented by 2 meshes: - _perforation_mesh : which is a sphere located where the perforation is - _extracted_cell : the extracted cell at the perforation location - """ def __init__( self, mesh: pv.PolyData, center: list[ float ], radius: float, actor: pv.Actor ) -> None: + """Class representing how storing a GEOS Perforation. + + A perforation is represented by 2 meshes: + _perforation_mesh : which is a sphere located where the perforation is + _extracted_cell : the extracted cell at the perforation location + """ self.perforation_mesh: pv.PolyData = mesh self.center: list[ float ] = center self.radius: float = radius @@ -21,18 +20,22 @@ def __init__( self, mesh: pv.PolyData, center: list[ float ], radius: float, act self.extracted_cell: pv.Actor def add_extracted_cell( self, cell_actor: pv.Actor ) -> None: + """Set the extracted cell to the given actor.""" self.extracted_cell = cell_actor def update_perforation_radius( self, value: float ) -> None: + """Update the perforation radius with the given value.""" self.radius = value self.perforation_mesh = pv.Sphere( radius=self.radius, center=self.center ) self.perforation_actor.GetMapper().SetInputDataObject( self.perforation_mesh ) self.perforation_actor.GetMapper().Update() def get_perforation_size( self ) -> float: + """Get the perforation radius.""" return self.radius def reset( self ) -> None: + """Reset the mesh, actor, and extracted cell.""" self.perforation_actor = pv.Actor() self.perforation_mesh = pv.PolyData() self.extracted_cell = pv.Actor() diff --git a/geos-trame/src/geos_trame/app/ui/viewer/regionViewer.py b/geos-trame/src/geos/trame/app/ui/viewer/regionViewer.py similarity index 67% rename from geos-trame/src/geos_trame/app/ui/viewer/regionViewer.py rename to geos-trame/src/geos/trame/app/ui/viewer/regionViewer.py index 0ab47c7fd..8929bf675 100644 --- a/geos-trame/src/geos_trame/app/ui/viewer/regionViewer.py +++ b/geos-trame/src/geos/trame/app/ui/viewer/regionViewer.py @@ -5,27 +5,30 @@ class RegionViewer: - """ - Stores all related data information to represent the whole mesh. - - This mesh is represented in GEOS with a Region. - """ def __init__( self ) -> None: - self.input: pv.UnstructuredGrid - self.clip: pv.UnstructuredGrid + """Stores all related data information to represent the whole mesh. + + This mesh is represented in GEOS with a Region. + """ + self.input = pv.UnstructuredGrid() + self.clip = self.input self.reset() def __call__( self, normal: tuple[ float ], origin: tuple[ float ] ) -> None: + """Update clip.""" self.update_clip( normal, origin ) def add_mesh( self, mesh: pv.UnstructuredGrid ) -> None: + """Set the input to the given mesh.""" self.input = mesh # type: ignore self.clip = self.input.copy() # type: ignore def update_clip( self, normal: tuple[ float ], origin: tuple[ float ] ) -> None: + """Update the current clip with the given normal and origin.""" self.clip.copy_from( self.input.clip( normal=normal, origin=origin, crinkle=True ) ) # type: ignore def reset( self ) -> None: + """Reset the input mesh and clip.""" self.input = pv.UnstructuredGrid() self.clip = self.input diff --git a/geos-trame/src/geos/trame/app/ui/viewer/viewer.py b/geos-trame/src/geos/trame/app/ui/viewer/viewer.py new file mode 100644 index 000000000..4b8495c03 --- /dev/null +++ b/geos-trame/src/geos/trame/app/ui/viewer/viewer.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Lucas Givord - Kitware +from typing import Any + +import pyvista as pv +from pydantic import BaseModel +from pyvista.trame.ui import plotter_ui +from trame.widgets import html +from trame.widgets import vuetify3 as vuetify +from vtkmodules.vtkRenderingCore import vtkActor + +from geos.trame.app.deck.tree import DeckTree +from geos.trame.app.ui.viewer.perforationViewer import PerforationViewer +from geos.trame.app.ui.viewer.regionViewer import RegionViewer +from geos.trame.app.ui.viewer.wellViewer import WellViewer +from geos.trame.schema_generated.schema_mod import ( + Vtkmesh, + Vtkwell, + Perforation, + InternalWell, +) + +pv.OFF_SCREEN = True + + +class DeckViewer( vuetify.VCard ): + + def __init__( + self, + source: DeckTree, + region_viewer: RegionViewer, + well_viewer: WellViewer, + **kwargs: Any, + ) -> None: + """Deck representing the 3D View using PyVista. + + This view can show: + - Vtkmesh, + - Vtkwell, + - Perforation, + - InternalWell + + Everything is handle in the method 'update_viewer()' which is trigger when the + 'state.object_state' changed (see DeckTree). + + This View handle widgets, such as clip widget or slider to control Wells or + Perforation settings. + """ + super().__init__( **kwargs ) + + self._source = source + self._pl = pv.Plotter() + + self.CUT_PLANE = "on_cut_plane_visibility_change" + self.ZAMPLIFICATION = "_z_amplification" + self.server.state[ self.CUT_PLANE ] = True + self.server.state[ self.ZAMPLIFICATION ] = 1 + + self.region_engine = region_viewer + self.well_engine = well_viewer + self._perforations: dict[ str, PerforationViewer ] = {} + + self.ctrl.update_viewer.add( self.update_viewer ) + + with self: + vuetify.VCardTitle( "3D View" ) + view = plotter_ui( + self._pl, + add_menu_items=self.rendering_menu_extra_items, + style="position: absolute;", + ) + self.ctrl.view_update = view.update + + @property + def plotter( self ) -> pv.Plotter: + """Getter for plotter.""" + return self._pl + + @property + def source( self ) -> DeckTree: + """Getter for source.""" + return self._source + + def rendering_menu_extra_items( self ) -> None: + """Extend the default pyvista menu with custom button. + + For now, adding a button to show/hide all widgets. + """ + self.state.change( self.CUT_PLANE )( self._on_clip_visibility_change ) + vuetify.VDivider( vertical=True, classes="mr-1" ) + with vuetify.VTooltip( location="bottom" ): + with ( + vuetify.Template( v_slot_activator=( "{ props }", ) ), + html.Div( v_bind=( "props", ) ), + ): + vuetify.VCheckbox( + v_model=( self.CUT_PLANE, True ), + icon=True, + true_icon="mdi-eye", + false_icon="mdi-eye-off", + dense=True, + hide_details=True, + ) + html.Span( "Show/Hide widgets" ) + + def update_viewer( self, active_block: BaseModel, path: str, show_obj: bool ) -> None: + """Add from path the dataset given by the user. + + Supported data type is: Vtkwell, Vtkmesh, InternalWell, Perforation. + + object_state : array used to store path to the data and if we want to show it or not. + """ + if isinstance( active_block, Vtkmesh ): + self._update_vtkmesh( show_obj ) + + if isinstance( active_block, Vtkwell ): + self._update_vtkwell( path, show_obj ) + + if isinstance( active_block, InternalWell ): + self._update_internalwell( path, show_obj ) + + if isinstance( active_block, Perforation ): + self._update_perforation( active_block, show_obj, path ) + + def _on_clip_visibility_change( self, **kwargs: Any ) -> None: + """Toggle cut plane visibility for all actors. + + Parameters + ---------- + **kwargs : dict, optional + Unused keyword arguments. + + """ + show_widgets = kwargs[ self.CUT_PLANE ] + if show_widgets: + self._setup_slider() + else: + self._remove_slider() + + if self.plotter.plane_widgets: + widgets = self.plotter.plane_widgets + widgets[ 0 ].SetEnabled( show_widgets ) + self.plotter.render() + + def _setup_slider( self ) -> None: + """Create slider to control in the gui well parameters.""" + wells_radius = self._get_tube_size() + self.plotter.add_slider_widget( + self._on_change_tube_size, + [ 1, 20 ], + title="Wells radius", + pointa=( 0.02, 0.12 ), + pointb=( 0.30, 0.12 ), + title_opacity=0.5, + title_color="black", + title_height=0.02, + value=wells_radius, + ) + + perforation_radius = self._get_perforation_size() + self.plotter.add_slider_widget( + self._on_change_perforation_size, + [ 1, 50 ], + title="Perforation radius", + title_opacity=0.5, + pointa=( 0.02, 0.25 ), + pointb=( 0.30, 0.25 ), + title_color="black", + title_height=0.02, + value=perforation_radius, + ) + + def _remove_slider( self ) -> None: + """Create slider to control in the gui well parameters.""" + self.plotter.clear_slider_widgets() + + def _on_change_tube_size( self, value: float ) -> None: + self.well_engine.update( value ) + + def _get_tube_size( self ) -> float: + return self.well_engine.get_tube_size() + + def _on_change_perforation_size( self, value: float ) -> None: + for _, perforation in self._perforations.items(): + perforation.update_perforation_radius( value ) + + def _get_perforation_size( self ) -> float | None: + if len( self._perforations ) <= 0: + return 5.0 + + for _, perforation in self._perforations.items(): + return perforation.get_perforation_size() + return None + + def _update_internalwell( self, path: str, show: bool ) -> None: + """Used to control the visibility of the InternalWell. + + This method will create the mesh if it doesn't exist. + """ + if not show: + self.plotter.remove_actor( self.well_engine.get_actor( path ) ) # type: ignore + return + + tube_actor = self.plotter.add_mesh( self.well_engine.get_tube( self.well_engine.get_last_mesh_idx() ) ) + self.well_engine.append_actor( path, tube_actor ) + + self.server.controller.view_update() + + def _update_vtkwell( self, path: str, show: bool ) -> None: + """Used to control the visibility of the Vtkwell. + + This method will create the mesh if it doesn't exist. + """ + if not show: + self.plotter.remove_actor( self.well_engine.get_actor( path ) ) # type: ignore + return + + tube_actor = self.plotter.add_mesh( self.well_engine.get_tube( self.well_engine.get_last_mesh_idx() ) ) + self.well_engine.append_actor( path, tube_actor ) + + self.server.controller.view_update() + + def _update_vtkmesh( self, show: bool ) -> None: + """Used to control the visibility of the Vtkmesh. + + This method will create the mesh if it doesn't exist. + + Additionally, a clip filter will be added. + """ + if not show: + self.plotter.clear_plane_widgets() + self.plotter.remove_actor( self._clip_mesh ) # type: ignore + return + + active_scalar = self.region_engine.input.active_scalars_name + self._clip_mesh: vtkActor = self.plotter.add_mesh_clip_plane( + self.region_engine.input, + origin=self.region_engine.input.center, + normal=[ -1, 0, 0 ], + crinkle=True, + show_edges=False, + cmap="glasbey_bw", + scalars=active_scalar, + ) + + self.server.controller.view_update() + + def _update_perforation( self, perforation: Perforation, show: bool, path: str ) -> None: + """Generate VTK dataset from a perforation.""" + if not show: + if path in self._perforations: + self._remove_perforation( path ) + return + + distance_from_head = float( perforation.distance_from_head ) + self._add_perforation( distance_from_head, path ) + + def _remove_perforation( self, path: str ) -> None: + """Remove all actor related to the given path and clean the stored perforation.""" + saved_perforation: PerforationViewer = self._perforations[ path ] + self.plotter.remove_actor( saved_perforation.extracted_cell ) # type: ignore + self.plotter.remove_actor( saved_perforation.perforation_actor ) # type: ignore + saved_perforation.reset() + + def _add_perforation( self, distance_from_head: float, path: str ) -> None: + """Generate perforation dataset based on the distance from the top of a polyline.""" + polyline: pv.PolyData | None = self.well_engine.get_mesh( path ) + if polyline is None: + return + + point = polyline.points[ 0 ] + point_offsetted = [ + point[ 0 ], + point[ 1 ], + point[ 2 ] - distance_from_head, + ] + + center = [ + float( point[ 0 ] ), + float( point[ 1 ] ), + point[ 2 ] - float( distance_from_head ), + ] + sphere = pv.Sphere( radius=5, center=center ) + + perforation_actor = self.plotter.add_mesh( sphere ) + saved_perforation = PerforationViewer( sphere, center, 5, perforation_actor ) + + cell_id = self.region_engine.input.find_closest_cell( point_offsetted ) + cell = self.region_engine.input.extract_cells( [ cell_id ] ) + cell_actor = self.plotter.add_mesh( cell ) + saved_perforation.add_extracted_cell( cell_actor ) + + self._perforations[ path ] = saved_perforation diff --git a/geos-trame/src/geos_trame/app/ui/viewer/wellViewer.py b/geos-trame/src/geos/trame/app/ui/viewer/wellViewer.py similarity index 75% rename from geos-trame/src/geos_trame/app/ui/viewer/wellViewer.py rename to geos-trame/src/geos/trame/app/ui/viewer/wellViewer.py index ece7d5482..4196f7b87 100644 --- a/geos-trame/src/geos_trame/app/ui/viewer/wellViewer.py +++ b/geos-trame/src/geos/trame/app/ui/viewer/wellViewer.py @@ -8,8 +8,7 @@ @dataclass class Well: - """ - A Well is represented by a polyline and a tube. + """A Well is represented by a polyline and a tube. This class stores also the related actor and his given path to simplify data management. @@ -22,13 +21,12 @@ class Well: class WellViewer: - """ - WellViewer stores all Well used in the pv.Plotter(). - - A Well in GEOS could a InternalWell or a Vtkwell. - """ def __init__( self, size: float, amplification: float ) -> None: + """WellViewer stores all Well used in the pv.Plotter(). + + A Well in GEOS could a InternalWell or a Vtkwell. + """ self._wells: list[ Well ] = [] self.size: float = size @@ -36,11 +34,16 @@ def __init__( self, size: float, amplification: float ) -> None: self.STARTING_VALUE: float = 5.0 def __call__( self, value: float ) -> None: + """Call update.""" self.update( value ) + def get_last_mesh_idx( self ) -> int: + """Returns the index of the last mesh.""" + return len( self._wells ) - 1 + def add_mesh( self, mesh: pv.PolyData, mesh_path: str ) -> int: - """ - Store a given mesh representing a polyline. + """Store a given mesh representing a polyline. + This polyline will be used then to create a tube to represent this line. return the indexed position of the stored well. @@ -53,9 +56,7 @@ def add_mesh( self, mesh: pv.PolyData, mesh_path: str ) -> int: return len( self._wells ) - 1 def get_mesh( self, perforation_path: str ) -> pv.PolyData | None: - """ - Retrieve the polyline linked to a given perforation path. - """ + """Retrieve the polyline linked to a given perforation path.""" index = self._get_index_from_perforation( perforation_path ) if index == -1: print( "Cannot found the well to remove from path: ", perforation_path ) @@ -64,9 +65,7 @@ def get_mesh( self, perforation_path: str ) -> pv.PolyData | None: return self._wells[ index ].polyline def get_tube( self, index: int ) -> pv.PolyData | None: - """ - Retrieve the polyline linked to a given perforation path. - """ + """Retrieve the polyline linked to a given perforation path.""" if index < 0 or index > len( self._wells ): print( "Cannot get the tube at index: ", index ) return None @@ -74,28 +73,20 @@ def get_tube( self, index: int ) -> pv.PolyData | None: return self._wells[ index ].tube def get_tube_size( self ) -> float: - """ - get the size used for the tube. - """ + """Get the size used for the tube.""" return self.size def append_actor( self, perforation_path: str, tube_actor: pv.Actor ) -> None: - """ - Append a given actor, typically the Actor returned by - the pv.Plotter() when a given mes is added. - """ - + """Append a given actor, typically the Actor returned by the pv.Plotter() when a given mes is added.""" index = self._get_index_from_perforation( perforation_path ) if index == -1: print( "Cannot found the well to remove from path: ", perforation_path ) - return None + return self._wells[ index ].actor = tube_actor def get_actor( self, perforation_path: str ) -> pv.Actor | None: - """ - Retrieve the polyline linked to a given perforation path. - """ + """Retrieve the polyline linked to a given perforation path.""" index = self._get_index_from_perforation( perforation_path ) if index == -1: print( "Cannot found the well to remove from path: ", perforation_path ) @@ -104,14 +95,13 @@ def get_actor( self, perforation_path: str ) -> pv.Actor | None: return self._wells[ index ].actor def update( self, value: float ) -> None: + """Update the radius of the tubes.""" self.size = value for idx, m in enumerate( self._wells ): self._wells[ idx ].tube.copy_from( m.polyline.tube( radius=self.size, n_sides=50 ) ) def remove( self, perforation_path: str ) -> None: - """ - Clear all data stored in this class. - """ + """Clear all data stored in this class.""" index = self._get_index_from_perforation( perforation_path ) if index == -1: print( "Cannot found the well to remove from path: ", perforation_path ) @@ -119,9 +109,7 @@ def remove( self, perforation_path: str ) -> None: self._wells.remove( self._wells[ index ] ) def _get_index_from_perforation( self, perforation_path: str ) -> int: - """ - Retrieve the well associated to a given perforation, otherwise return -1. - """ + """Retrieve the well associated to a given perforation, otherwise return -1.""" index = -1 if len( self._wells ) == 0: return index @@ -133,5 +121,6 @@ def _get_index_from_perforation( self, perforation_path: str ) -> int: return index - def get_number_of_wells( self ): + def get_number_of_wells( self ) -> int: + """Get the number of wells in the viewer.""" return len( self._wells ) diff --git a/geos-trame/src/geos/trame/app/utils/__init__.py b/geos-trame/src/geos/trame/app/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/geos-trame/src/geos/trame/app/utils/dict_utils.py b/geos-trame/src/geos/trame/app/utils/dict_utils.py new file mode 100644 index 000000000..46df9a050 --- /dev/null +++ b/geos-trame/src/geos/trame/app/utils/dict_utils.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +from typing import Any + + +def iterate_nested_dict( iterable: dict | list, returned: str = "key" ) -> Any: + """Returns an iterator that returns all keys or values of a (nested) iterable. + + Arguments: + iterable: or + returned: "key" or "value" + + Returns: + - + """ + if isinstance( iterable, dict ): + for key, value in iterable.items(): + if key == "id" and not isinstance( value, ( dict, list ) ): + yield value + for ret in iterate_nested_dict( value, returned=returned ): + yield ret + elif isinstance( iterable, list ): + for el in iterable: + for ret in iterate_nested_dict( el, returned=returned ): + yield ret diff --git a/geos-trame/src/geos/trame/app/utils/file_utils.py b/geos-trame/src/geos/trame/app/utils/file_utils.py new file mode 100644 index 000000000..9ffcc5319 --- /dev/null +++ b/geos-trame/src/geos/trame/app/utils/file_utils.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +import os +import re +from io import StringIO +from typing import Any, List, TextIO +from lxml import etree as ElementTree # type: ignore[import-untyped] + + +def normalize_path( x: str ) -> str: + """Normalize the given path.""" + tmp = os.path.expanduser( x ) + tmp = os.path.abspath( tmp ) + if os.path.isfile( tmp ): + x = tmp + return x + + +def format_attribute( attribute_indent: str, ka: str, attribute_value: str ) -> str: + """Format xml attribute strings. + + Args: + attribute_indent (str): Attribute indent string + ka (str): Attribute name + attribute_value (str): Attribute value + + Returns: + str: Formatted attribute value + """ + # Make sure that a space follows commas + attribute_value = re.sub( r",\s*", ", ", attribute_value ) + + # Handle external brackets + attribute_value = re.sub( r"{\s*", "{ ", attribute_value ) + attribute_value = re.sub( r"\s*}", " }", attribute_value ) + + # Consolidate whitespace + attribute_value = re.sub( r"\s+", " ", attribute_value ) + + # Identify and split multi-line attributes + if re.match( r"\s*{\s*({[-+.,0-9a-zA-Z\s]*},?\s*)*\s*}", attribute_value ): + split_positions: List[ Any ] = [ match.end() for match in re.finditer( r"}\s*,", attribute_value ) ] + newline_indent = "\n%s" % ( " " * ( len( attribute_indent ) + len( ka ) + 4 ) ) + new_values = [] + for a, b in zip( [ None ] + split_positions, split_positions + [ None ] ): + new_values.append( attribute_value[ a:b ].strip() ) + if new_values: + attribute_value = newline_indent.join( new_values ) + + return attribute_value + + +def format_xml_level( + output: TextIO, + node: ElementTree.Element, + level: int, + indent: str = " " * 2, + block_separation_max_depth: int = 2, + modify_attribute_indent: bool = False, + sort_attributes: bool = False, + close_tag_newline: bool = False, + include_namespace: bool = False, +) -> None: + """Iteratively format the xml file. + + Args: + output (file): the output text file handle + node (lxml.etree.Element): the current xml element + level (int): the xml depth + indent (str): the xml indent style + block_separation_max_depth (int): the maximum depth to separate adjacent elements + modify_attribute_indent (bool): option to have flexible attribute indentation + sort_attributes (bool): option to sort attributes alphabetically + close_tag_newline (bool): option to place close tag on a separate line + include_namespace (bool): option to include the xml namespace in the output + """ + # Handle comments + if node.tag is ElementTree.Comment: + output.write( "\n%s" % ( indent * level, node.text ) ) + + else: + # Write opening line + opening_line = "\n%s<%s" % ( indent * level, node.tag ) + output.write( opening_line ) + + # Write attributes + if len( node.attrib ) > 0: + # Choose indentation + attribute_indent = "%s" % ( indent * ( level + 1 ) ) + if modify_attribute_indent: + attribute_indent = " " * ( len( opening_line ) ) + + # Get a copy of the attributes + attribute_dict = node.attrib + + # Sort attribute names + akeys = list( attribute_dict.keys() ) + if sort_attributes: + akeys = sorted( akeys ) + + # Format attributes + for ka in akeys: + # Avoid formatting mathpresso expressions + if not ( node.tag in [ "SymbolicFunction", "CompositeFunction" ] and ka == "expression" ): + attribute_dict[ ka ] = format_attribute( attribute_indent, ka, attribute_dict[ ka ] ) + + for ii in range( 0, len( akeys ) ): + k = akeys[ ii ] + if ( ii == 0 ) & modify_attribute_indent: + # TODO: attrib_ute_dict isn't define here which leads to an error + # output.write(' %s="%s"' % (k, attrib_ute_dict[k])) + pass + else: + output.write( '\n%s%s="%s"' % ( attribute_indent, k, attribute_dict[ k ] ) ) + + # Write children + if len( node ): + output.write( ">" ) + Nc = len( node ) + for ii, child in zip( range( Nc ), node ): + format_xml_level( + output, + child, + level + 1, + indent, + block_separation_max_depth, + modify_attribute_indent, + sort_attributes, + close_tag_newline, + include_namespace, + ) + + # Add space between blocks + if ( ( level < block_separation_max_depth ) + & ( ii < Nc - 1 ) + & ( child.tag is not ElementTree.Comment ) ): + output.write( "\n" ) + + # Write the end tag + output.write( "\n%s" % ( indent * level, node.tag ) ) + else: + if close_tag_newline: + output.write( "\n%s/>" % ( indent * level ) ) + else: + output.write( "/>" ) + + +def format_xml( + input_str: str, + indent_size: int = 2, + indent_style: bool = False, + block_separation_max_depth: int = 2, + alphabetize_attributes: bool = False, + close_style: bool = False, + namespace: bool = False, +) -> str: + """Script to format xml files. + + Args: + input_str (str): Input str + indent_size (int): Indent size + indent_style (bool): Style of indentation (0=fixed, 1=hanging) + block_separation_max_depth (int): Max depth to separate xml blocks + alphabetize_attributes (bool): Alphebitize attributes + close_style (bool): Style of close tag (0=same line, 1=new line) + namespace (bool): Insert this namespace in the xml description + """ + try: + root = ElementTree.fromstring( input_str ) + prologue_comments = [ tmp.text for tmp in root.itersiblings( preceding=True ) ] + epilog_comments = [ tmp.text for tmp in root.itersiblings() ] + + f = StringIO() + f.write( '\n' ) + + for comment in reversed( prologue_comments ): + f.write( "\n" % comment ) + + format_xml_level( + f, + root, + 0, + indent=" " * indent_size, + block_separation_max_depth=block_separation_max_depth, + modify_attribute_indent=indent_style, + sort_attributes=alphabetize_attributes, + close_tag_newline=close_style, + include_namespace=namespace, + ) + + for comment in epilog_comments: + f.write( "\n" % comment ) + f.write( "\n" ) + + return f.getvalue() + + except ElementTree.ParseError as err: + print( err.msg ) + raise Exception( "Failed to format xml file" ) from err diff --git a/geos-trame/src/geos/trame/app/utils/geos_utils.py b/geos-trame/src/geos/trame/app/utils/geos_utils.py new file mode 100644 index 000000000..2c2af348b --- /dev/null +++ b/geos-trame/src/geos/trame/app/utils/geos_utils.py @@ -0,0 +1,11 @@ +def group_name_ref_array_to_list( group_name_ref_array: str ) -> list[ str ] | None: + """Convert GEOS type groupNameRef_array to a list of string. + + Example: "{ test1, test2 }" becomes ["test1", "test2"] + """ + if ( not group_name_ref_array or not group_name_ref_array.strip().startswith( '{' ) + or not group_name_ref_array.strip().endswith( '}' ) ): + return None + + stripped = group_name_ref_array.strip().strip( '{}' ) + return [ item.strip() for item in stripped.split( ',' ) if item.strip() ] diff --git a/geos-trame/src/geos/trame/app/utils/pv_utils.py b/geos-trame/src/geos/trame/app/utils/pv_utils.py new file mode 100644 index 000000000..cb3af1330 --- /dev/null +++ b/geos-trame/src/geos/trame/app/utils/pv_utils.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +import pyvista as pv + + +def read_unstructured_grid( filename: str ) -> pv.UnstructuredGrid: + """Read an unstructured grid from a .vtu file.""" + return pv.read( filename ).cast_to_unstructured_grid() diff --git a/geos-trame/src/geos_trame/module/.gitignore b/geos-trame/src/geos/trame/module/.gitignore similarity index 100% rename from geos-trame/src/geos_trame/module/.gitignore rename to geos-trame/src/geos/trame/module/.gitignore diff --git a/geos-trame/src/geos_trame/module/__init__.py b/geos-trame/src/geos/trame/module/__init__.py similarity index 92% rename from geos-trame/src/geos_trame/module/__init__.py rename to geos-trame/src/geos/trame/module/__init__.py index 5e04d0c9b..1c705fceb 100644 --- a/geos-trame/src/geos_trame/module/__init__.py +++ b/geos-trame/src/geos/trame/module/__init__.py @@ -20,6 +20,6 @@ # Optional if you want to execute custom initialization at module load -def setup( app, **kwargs ): - """Method called at initialization with possibly some custom keyword arguments""" +def setup( app, **kwargs ): # noqa + """Method called at initialization with possibly some custom keyword arguments.""" pass diff --git a/geos-trame/src/geos_trame/schema_generated/README.md b/geos-trame/src/geos/trame/schema_generated/README.md similarity index 100% rename from geos-trame/src/geos_trame/schema_generated/README.md rename to geos-trame/src/geos/trame/schema_generated/README.md diff --git a/geos-trame/src/geos/trame/schema_generated/__init__.py b/geos-trame/src/geos/trame/schema_generated/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/geos-trame/src/geos_trame/schema_generated/old_schema_mod.py b/geos-trame/src/geos/trame/schema_generated/old_schema_mod.py similarity index 99% rename from geos-trame/src/geos_trame/schema_generated/old_schema_mod.py rename to geos-trame/src/geos/trame/schema_generated/old_schema_mod.py index 909a521a9..45919c59b 100644 --- a/geos-trame/src/geos_trame/schema_generated/old_schema_mod.py +++ b/geos-trame/src/geos/trame/schema_generated/old_schema_mod.py @@ -4,6 +4,8 @@ See: https://xsdata.readthedocs.io/ """ +# ruff: noqa + from __future__ import annotations from dataclasses import field diff --git a/geos-trame/src/geos_trame/schema_generated/schema_mod.py b/geos-trame/src/geos/trame/schema_generated/schema_mod.py similarity index 99% rename from geos-trame/src/geos_trame/schema_generated/schema_mod.py rename to geos-trame/src/geos/trame/schema_generated/schema_mod.py index 7c0105a3d..cc62720dc 100644 --- a/geos-trame/src/geos_trame/schema_generated/schema_mod.py +++ b/geos-trame/src/geos/trame/schema_generated/schema_mod.py @@ -4,6 +4,8 @@ See: https://xsdata.readthedocs.io/ """ +# ruff: noqa + from typing import List, Optional from pydantic import BaseModel, ConfigDict @@ -16302,8 +16304,3 @@ class Meta: "namespace": "", }, ) - - -class Problem( Problem ): - pass - model_config = ConfigDict( defer_build=True ) diff --git a/geos-trame/src/geos_trame/app/deck/file.py b/geos-trame/src/geos_trame/app/deck/file.py deleted file mode 100644 index 3840aa534..000000000 --- a/geos-trame/src/geos_trame/app/deck/file.py +++ /dev/null @@ -1,439 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. -# SPDX-FileContributor: Lionel Untereiner -import os -import re -import typing -from dataclasses import fields, is_dataclass -from io import StringIO -from typing import Any, Iterator, List, TextIO - -import typing_extensions -import typing_inspect -from lxml import etree as ElementTree # type: ignore[import-untyped] -from pydantic import BaseModel - -# from xsdata.formats.dataclass.context import XmlContext -# from xsdata.formats.dataclass.parsers import XmlParser -from xsdata.formats.dataclass.parsers.config import ParserConfig - -# from xsdata.formats.dataclass.serializers import DictEncoder -# from xsdata.formats.dataclass.serializers import XmlSerializer -from xsdata.formats.dataclass.serializers.config import SerializerConfig -from xsdata.utils import text -from xsdata_pydantic.bindings import DictEncoder, XmlContext, XmlParser, XmlSerializer - -# from geos_trame.app.deck.inspector import object_to_tree -from geos_trame.app.geosTrameException import GeosTrameException -from geos_trame.schema_generated.schema_mod import Problem - -from geos_trame.app.io.xml_parser import XMLParser - - -def get_origin( v: typing.Any ) -> typing.Any: - pydantic_generic_metadata = getattr( v, "__pydantic_generic_metadata__", None ) #: PydanticGenericMetadata | None - if pydantic_generic_metadata: - return pydantic_generic_metadata.get( "origin" ) - return typing_extensions.get_origin( v ) - - -def all_fields( c: type, already_checked ) -> list[ str ]: - resolved_hints = typing.get_type_hints( c ) - field_names = [ field.name for field in fields( c ) ] - resolved_field_types = { name: resolved_hints[ name ] for name in field_names } - - field_list = [] - for key in resolved_field_types: - current_type = resolved_field_types[ key ] - if typing_inspect.get_origin( current_type ) in ( list, typing.List ): - inner_type = typing_inspect.get_args( current_type )[ 0 ] - if inner_type not in already_checked: - already_checked.append( inner_type ) - field_list.extend( all_fields( inner_type, already_checked ) ) - if is_dataclass( current_type ) and current_type not in already_checked: - already_checked.append( current_type ) - field_list.extend( all_fields( current_type, already_checked ) ) - - # {"id": i, "name": f, "children": [], "hidden_children": []} - - return field_list - - -def required_fields( model: type[ BaseModel ], recursive: bool = False ) -> Iterator[ str ]: - for name, field in model.model_fields.items(): - print( name ) - if not field.is_required(): - continue - t = field.annotation - print( t ) - if recursive and isinstance( t, type ) and issubclass( t, BaseModel ): - yield from required_fields( t, recursive=True ) - else: - yield name - - -def is_pydantic_model( obj ): - try: - return issubclass( obj, BaseModel ) - except TypeError: - return False - - -def show_hierarchy( Model: BaseModel, processed_types: set, indent: int = 0 ): - print( type( Model ).__name__ ) - if type( Model ).__name__ not in processed_types: - processed_types.add( type( Model ).__name__ ) - print( processed_types ) - for k, v in Model.model_fields.items(): - print( f'{" "*indent}{k}: ' - f"type={v.annotation}, " - f"required={v.is_required()}" ) - if is_pydantic_model( typing.get_args( v.annotation )[ 0 ] ): - # print("plop") - show_hierarchy( typing.get_args( v.annotation )[ 0 ], processed_types, indent + 2 ) - - -def normalize_path( x ): - tmp = os.path.expanduser( x ) - tmp = os.path.abspath( tmp ) - if os.path.isfile( tmp ): - x = tmp - return x - - -class DeckFile( object ): - """ - Holds the information of a deck file. - Can be empty. - """ - - def __init__( self, filename: str, **kwargs ) -> None: - """ - Constructor. - Input: - filename: file name of the deck file - """ - super( DeckFile, self ).__init__( **kwargs ) - - self.root_node = None - self.filename = normalize_path( filename ) - if self.filename: - self.open_deck_file( self.filename ) - self.original_text = "" - self.changed = False - - self.path = os.path.dirname( self.filename ) - - def open_deck_file( self, filename: str ) -> None: - """ - Opens a file and parses it. - Input: - filename: file name of the input file - Signals: - input_file_changed: On success - Raises: - GeosTrameException: On invalid input file - """ - - self.changed = False - self.root_node = None - - # Do some basic checks on the filename to make sure - # it is probably a real input file since the GetPot - # parser doesn't do any checks. - if not os.path.exists( filename ): - msg = "Input file %s does not exist" % filename - raise GeosTrameException( msg ) - - if not os.path.isfile( filename ): - msg = "Input file %s is not a file" % filename - raise GeosTrameException( msg ) - - if not filename.endswith( ".xml" ): - msg = "Input file %s does not have the proper extension" % filename - raise GeosTrameException( msg ) - - self.xml_parser = XMLParser( filename=filename ) - self.xml_parser.build() - simulation_deck = self.xml_parser.get_simulation_deck() - - context = XmlContext( - element_name_generator=text.pascal_case, - attribute_name_generator=text.camel_case, - ) - parser = XmlParser( context=context, config=ParserConfig( - ) ) # fail_on_unknown_properties=True, fail_on_unknown_attributes=True, fail_on_converter_warnings=True)) - try: - self.problem = parser.parse( simulation_deck, Problem ) - except ElementTree.XMLSyntaxError as e: - msg = "Failed to parse input file %s:\n%s\n" % ( filename, e ) - raise GeosTrameException( msg ) - - encoder = DictEncoder( context=context, config=SerializerConfig( indent=" " ) ) - self.pb_dict = { "Problem": encoder.encode( self.problem ) } - self.inspect_tree = build_inspect_tree( encoder.encode( self.problem ) ) - - def to_str( self ) -> str: - config = SerializerConfig( indent=" ", xml_declaration=False ) - context = XmlContext( - element_name_generator=text.pascal_case, - attribute_name_generator=text.camel_case, - ) - serializer = XmlSerializer( context=context, config=config ) - return serializer.render( self.problem ) - - -def build_inspect_tree( obj, *, dict_factory=dict ) -> dict: - """Return the fields of a dataclass instance as a new dictionary mapping - field names to field values. - - Example usage:: - - @dataclass - class C: - x: int - y: int - - c = C(1, 2) - assert asdict(c) == {'x': 1, 'y': 2} - - If given, 'dict_factory' will be used instead of built-in dict. - The function applies recursively to field values that are - dataclass instances. This will also look into built-in containers: - tuples, lists, and dicts. Other objects are copied with 'copy.deepcopy()'. - """ - # if not _is_dataclass_instance(obj): - # raise TypeError("asdict() should be called on dataclass instances") - - return _build_inspect_tree_inner( "Problem", obj, [] ) - - -def _build_inspect_tree_inner( key, obj, path ) -> dict: - sub_node = dict() - if "name" in obj: - sub_node[ "title" ] = obj[ "name" ] - else: - sub_node[ "title" ] = key - # sub_node["id"] = randrange(150) - sub_node[ "children" ] = list() - # sub_node["hidden_children"] = list() - sub_node[ "is_drawable" ] = key in [ - "VTKMesh", - "InternalMesh", - "InternalWell", - "VTKWell", - "Perforation", - ] - sub_node[ "drawn" ] = False - - for key, value in obj.items(): - - if isinstance( value, list ): - for idx, item in enumerate( value ): - if isinstance( item, dict ): - more_results = _build_inspect_tree_inner( key, item, path + [ key ] + [ idx ] ) - # for another_result in more_results: - sub_node[ "children" ].append( more_results ) - - # sub_node["path"] = path + [sub_node["name"]] - sub_node[ "id" ] = "Problem/" + "/".join( map( str, path ) ) - - return sub_node - - -def format_attribute( attribute_indent: str, ka: str, attribute_value: str ) -> str: - """Format xml attribute strings - - Args: - attribute_indent (str): Attribute indent string - ka (str): Attribute name - attribute_value (str): Attribute value - - Returns: - str: Formatted attribute value - """ - # Make sure that a space follows commas - attribute_value = re.sub( r",\s*", ", ", attribute_value ) - - # Handle external brackets - attribute_value = re.sub( r"{\s*", "{ ", attribute_value ) - attribute_value = re.sub( r"\s*}", " }", attribute_value ) - - # Consolidate whitespace - attribute_value = re.sub( r"\s+", " ", attribute_value ) - - # Identify and split multi-line attributes - if re.match( r"\s*{\s*({[-+.,0-9a-zA-Z\s]*},?\s*)*\s*}", attribute_value ): - split_positions: List[ Any ] = [ match.end() for match in re.finditer( r"}\s*,", attribute_value ) ] - newline_indent = "\n%s" % ( " " * ( len( attribute_indent ) + len( ka ) + 4 ) ) - new_values = [] - for a, b in zip( [ None ] + split_positions, split_positions + [ None ] ): - new_values.append( attribute_value[ a:b ].strip() ) - if new_values: - attribute_value = newline_indent.join( new_values ) - - return attribute_value - - -def format_xml_level( - output: TextIO, - node: ElementTree.Element, - level: int, - indent: str = " " * 2, - block_separation_max_depth: int = 2, - modify_attribute_indent: bool = False, - sort_attributes: bool = False, - close_tag_newline: bool = False, - include_namespace: bool = False, -) -> None: - """Iteratively format the xml file - - Args: - output (file): the output text file handle - node (lxml.etree.Element): the current xml element - level (int): the xml depth - indent (str): the xml indent style - block_separation_max_depth (int): the maximum depth to separate adjacent elements - modify_attribute_indent (bool): option to have flexible attribute indentation - sort_attributes (bool): option to sort attributes alphabetically - close_tag_newline (bool): option to place close tag on a separate line - include_namespace (bool): option to include the xml namespace in the output - """ - - # Handle comments - if node.tag is ElementTree.Comment: - output.write( "\n%s" % ( indent * level, node.text ) ) - - else: - # Write opening line - opening_line = "\n%s<%s" % ( indent * level, node.tag ) - output.write( opening_line ) - - # Write attributes - if len( node.attrib ) > 0: - # Choose indentation - attribute_indent = "%s" % ( indent * ( level + 1 ) ) - if modify_attribute_indent: - attribute_indent = " " * ( len( opening_line ) ) - - # Get a copy of the attributes - attribute_dict = {} - attribute_dict = node.attrib - - # Sort attribute names - akeys = list( attribute_dict.keys() ) - if sort_attributes: - akeys = sorted( akeys ) - - # Format attributes - for ka in akeys: - # Avoid formatting mathpresso expressions - if not ( node.tag in [ "SymbolicFunction", "CompositeFunction" ] and ka == "expression" ): - attribute_dict[ ka ] = format_attribute( attribute_indent, ka, attribute_dict[ ka ] ) - - for ii in range( 0, len( akeys ) ): - k = akeys[ ii ] - if ( ii == 0 ) & modify_attribute_indent: - # TODO: attrib_ute_dict isn't define here which leads to an error - # output.write(' %s="%s"' % (k, attrib_ute_dict[k])) - pass - else: - output.write( '\n%s%s="%s"' % ( attribute_indent, k, attribute_dict[ k ] ) ) - - # Write children - if len( node ): - output.write( ">" ) - Nc = len( node ) - for ii, child in zip( range( Nc ), node ): - format_xml_level( - output, - child, - level + 1, - indent, - block_separation_max_depth, - modify_attribute_indent, - sort_attributes, - close_tag_newline, - include_namespace, - ) - - # Add space between blocks - if ( ( level < block_separation_max_depth ) - & ( ii < Nc - 1 ) - & ( child.tag is not ElementTree.Comment ) ): - output.write( "\n" ) - - # Write the end tag - output.write( "\n%s" % ( indent * level, node.tag ) ) - else: - if close_tag_newline: - output.write( "\n%s/>" % ( indent * level ) ) - else: - output.write( "/>" ) - - -def format_xml( - input: str, - indent_size: int = 2, - indent_style: bool = False, - block_separation_max_depth: int = 2, - alphebitize_attributes: bool = False, - close_style: bool = False, - namespace: bool = False, -) -> None: - """Script to format xml files - - Args: - input (str): Input str - indent_size (int): Indent size - indent_style (bool): Style of indentation (0=fixed, 1=hanging) - block_separation_max_depth (int): Max depth to separate xml blocks - alphebitize_attributes (bool): Alphebitize attributes - close_style (bool): Style of close tag (0=same line, 1=new line) - namespace (bool): Insert this namespace in the xml description - """ - try: - root = ElementTree.fromstring( input ) - # root = tree.getroot() - prologue_comments = [ tmp.text for tmp in root.itersiblings( preceding=True ) ] - epilog_comments = [ tmp.text for tmp in root.itersiblings() ] - - f = StringIO() - f.write( '\n' ) - - for comment in reversed( prologue_comments ): - f.write( "\n" % ( comment ) ) - - format_xml_level( - f, - root, - 0, - indent=" " * indent_size, - block_separation_max_depth=block_separation_max_depth, - modify_attribute_indent=indent_style, - sort_attributes=alphebitize_attributes, - close_tag_newline=close_style, - include_namespace=namespace, - ) - - for comment in epilog_comments: - f.write( "\n" % ( comment ) ) - f.write( "\n" ) - - return f.getvalue() - - except ElementTree.ParseError as err: - print( "\nCould not load file: %s" % ( f ) ) - print( err.msg ) - raise Exception( "\nCheck input file!" ) - - -if __name__ == "__main__": - import sys - - if len( sys.argv ) < 2: - print( "Need an input file as argument" ) - exit( 1 ) - filename = sys.argv[ 1 ] - deck_file = DeckFile( filename ) - print( deck_file.root_fields ) diff --git a/geos-trame/src/geos_trame/app/ui/inspector.py b/geos-trame/src/geos_trame/app/ui/inspector.py deleted file mode 100644 index 309ba0994..000000000 --- a/geos-trame/src/geos_trame/app/ui/inspector.py +++ /dev/null @@ -1,307 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. -# SPDX-FileContributor: Lionel Untereiner -from dataclasses import dataclass -from enum import Enum - -import yaml -from pydantic import BaseModel -from trame.widgets import vuetify3 as vuetify -from trame.widgets import html -from trame_simput import get_simput_manager -from typing import Any - - -class Renderable( Enum ): - VTKMESH = "VTKMesh" - INTERNALMESH = "InternalMesh" - INTERNALWELL = "InternalWell" - VTKWELL = "VTKWell" - PERFORATION = "Perforation" - - -# Pure pydantic version -# -# class TreeNode(BaseModel): -# id: str -# name: str -# is_drawable: bool -# drawn: bool -# children: list['TreeNode'] -# hidden_children: list['TreeNode'] - -# def get_node(obj, node_id, path): -# children = [] -# for name, info in obj.model_fields.items(): -# if name in obj.model_fields_set: -# print(type(info)) -# print(name, "-", info.annotation, " - ", get_origin(info.annotation), get_args(info.annotation)[0]) -# metadata = getattr(info, "xsdata_metadata", None) or {} -# print(metadata["name"]) -# if get_origin(info.annotation) is list: -# attr= getattr(obj, name) -# print(attr) -# for idx, item in enumerate(attr): -# children.append(get_node(item, name, path + [name] + [idx])) - -# return TreeNode( -# id = "Problem/" + "/".join(map(str, path)), -# name = "metadata", -# children = children, -# hidden_children = [], -# is_drawable = node_id in Renderable, -# drawn = False, -# ) - - -@dataclass -class TreeNode: - id: str - title: str - children: list - hidden_children: list - is_drawable: bool - drawn: bool - - @property - def json( self ) -> dict: - if self.children: - return dict( - id=self.id, - title=self.title, - is_drawable=self.is_drawable, - drawn=self.drawn, - children=[ c.json for c in self.children ], - hidden_children=[ c.json for c in self.hidden_children ], - ) - return dict( - id=self.id, - title=self.title, - is_drawable=self.is_drawable, - drawn=self.drawn, - children=None, - hidden_children=[], - ) - - -def get_node_dict( obj, node_id, path ): - children = [] - for key, value in obj.items(): - # todo look isinstance(value, dict): - if isinstance( value, list ): - for idx, item in enumerate( value ): - if isinstance( item, dict ): - children.append( get_node_dict( item, key, path + [ key ] + [ idx ] ) ) - - node_name = node_id - if "name" in obj: - node_name = obj[ "name" ] - - return TreeNode( - id="Problem/" + "/".join( map( str, path ) ), - title=node_name, - children=children if len( children ) else [], - hidden_children=[], - is_drawable=node_id in ( k.value for k in Renderable ), - drawn=False, - ) - - -def object_to_tree( obj: dict ) -> dict: - return get_node_dict( obj, "Problem", [] ).json - - -def dump( item ): - match item: - case BaseModel() as model: - subitems: dict[ str, Any ] = dict() - model.model_fields - - for field, value in model: - - if isinstance( value, str ): - subitems[ field ] = value - continue - - return subitems - case list() | tuple() | set(): # pyright: ignore - # Pyright finds this disgusting; this passes `mypy` though. ` # type: - # ignore` would fail `mypy` is it'd be unused (because there's nothing to - # ignore because `mypy` is content) - # return type(container)( # pyright: ignore - # _dump(i) for i in container # pyright: ignore - # ) - pass - case dict(): - # return { - # k: _dump(v) - # for k, v in item.items() # pyright: ignore[reportUnknownVariableType] - # } - pass - case _: - return item - - -def iterate_nested_dict( iterable, returned="key" ): - """Returns an iterator that returns all keys or values - of a (nested) iterable. - - Arguments: - - iterable: or - - returned: "key" or "value" - - Returns: - - - """ - - if isinstance( iterable, dict ): - for key, value in iterable.items(): - if key == "id": - if not ( isinstance( value, dict ) or isinstance( value, list ) ): - yield value - # else: - # raise ValueError("'returned' keyword only accepts 'key' or 'value'.") - for ret in iterate_nested_dict( value, returned=returned ): - yield ret - elif isinstance( iterable, list ): - for el in iterable: - for ret in iterate_nested_dict( el, returned=returned ): - yield ret - - -vuetify.enable_lab() - - -class DeckInspector( vuetify.VTreeview ): - - def __init__( self, listen_to_active=True, source=None, **kwargs ): - super().__init__( - # data - items=( "deck_tree", ), - item_value="id", - **{ - # style - "hoverable": True, - "max_width": 500, - "rounded": True, - # activation logic - "activatable": True, - "activated": ( "active_ids", ), - "active_strategy": "single-independent", - "update_activated": ( self.change_current_id, "$event" ), - # selection logic - "selectable": False, - **kwargs, - }, - ) - self.tree = source - self._source = None - self.listen_to_active = listen_to_active - - self.state.object_state = [ "", False ] - - # register used types from Problem - self.simput_types = [] - - self.simput_manager = get_simput_manager( id=self.state.sm_id ) - - if source.input_file is None: - return - - self.set_source( source.input_file.problem ) - - def on_change( topic, ids=None, **kwargs ): - if topic == "changed": - for obj_id in ids: - proxy = self.simput_manager.proxymanager.get( obj_id ) - self.tree.decode( obj_id ) - for prop in proxy.edited_property_names: - self.tree.update( obj_id, prop, proxy.get_property( prop ) ) - - self.simput_manager.proxymanager.on( on_change ) - - with self: - with vuetify.Template( v_slot_append="{ item }" ): - vuetify.VCheckboxBtn( v_if="item.is_drawable", - focused=True, - dense=True, - hide_details=True, - icon=True, - false_icon="mdi-eye-off", - true_icon="mdi-eye", - update_modelValue=( self.to_draw_change, "[ item, item.id, $event ] " ) ) - - def to_draw_change( self, item, item_id, drawn ): - self.state.object_state = [ item_id, drawn ] - - @property - def source( self ): - return self._source - - # TODO - # v should be a proxy like the one in paraview simple - # maybe it can be Any of schema_mod (e.g. Problem) - def set_source( self, v ): - - # TODO replace this snippet - from xsdata.formats.dataclass.serializers.config import SerializerConfig - from xsdata.utils import text - from xsdata_pydantic.bindings import DictEncoder, XmlContext - - context = XmlContext( - element_name_generator=text.pascal_case, - attribute_name_generator=text.camel_case, - ) - - encoder = DictEncoder( context=context, config=SerializerConfig( indent=" " ) ) - self._source = encoder.encode( v ) - # with this one by passing v as Problem - # self._source = v - - if v is None: - self.state.deck_tree = [] - else: - self.state.deck_tree = object_to_tree( self._source ).get( "children", [] ) - - for v in iterate_nested_dict( self.state.deck_tree ): - - active_block = self.tree.decode( v ) - # active_name = None - - # if hasattr(active_block, "name"): - # active_name = active_block.name - - simput_type = type( active_block ).__name__ - - test = dump( active_block ) - - if test: - params_dict = {} - for key, value in test.items(): - params_dict[ key ] = { - # "initial": str(v), - "type": "string", - } - - self.simput_types.append( simput_type ) - yaml_str = yaml.dump( { simput_type: params_dict }, sort_keys=False ) - - self.simput_manager.load_model( yaml_content=yaml_str ) - - debug = self.simput_manager.proxymanager.create( simput_type, proxy_id=v ) - - for key, value in test.items(): - debug.set_property( key, getattr( active_block, key ) ) - debug.commit() - - def change_current_id( self, item_id=None ): - """ - Change the current id of the tree. - This function is called when the user click on the tree. - """ - if item_id is None: - # Silently ignore, it could occurs is the user click on the tree - # and this item is already selected - return - - self.state.active_id = item_id diff --git a/geos-trame/src/geos_trame/app/ui/viewer/viewer.py b/geos-trame/src/geos_trame/app/ui/viewer/viewer.py deleted file mode 100644 index b5d11acdf..000000000 --- a/geos-trame/src/geos_trame/app/ui/viewer/viewer.py +++ /dev/null @@ -1,368 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. -# SPDX-FileContributor: Lucas Givord - Kitware -import pyvista as pv -from pyvista.trame.ui import plotter_ui -from trame.widgets import vuetify3 as vuetify -from trame.widgets import html - -from geos_trame.schema_generated.schema_mod import ( - Vtkmesh, - Vtkwell, - Perforation, - InternalWell, -) - -import geos_trame.app.ui.viewer.regionViewer as RegionViewer -import geos_trame.app.ui.viewer.wellViewer as WellViewer -import geos_trame.app.ui.viewer.perforationViewer as PerforationViewer -from geos_trame.app.geosTrameException import GeosTrameException - -import numpy as np -from typing import Type, Any - -pv.OFF_SCREEN = True - - -class DeckViewer( vuetify.VCard ): - """ - Deck representing the 3D View using PyVista. - - This view can show: - - Vtkmesh, - - Vtkwell, - - Perforation, - - InternalWell - - Everything is handle in the method 'update_viewer()' which is trigger when the - 'state.object_state' changed (see DeckTree). - - This View handle widgets, such as clip widget or slider to control Wells or - Perforation settings. - """ - - def __init__( self, source, **kwargs ): - super().__init__( **kwargs ) - - self._source = source - self._pl = pv.Plotter() - - self.CUT_PLANE = "on_cut_plane_visibility_change" - self.ZAMPLIFICATION = "_z_amplification" - self.server.state[ self.CUT_PLANE ] = True - self.server.state[ self.ZAMPLIFICATION ] = 1 - - self.region_engine = RegionViewer.RegionViewer() - self.well_engine = WellViewer.WellViewer( 5, 5 ) - self._perforations: dict[ str, PerforationViewer.PerforationViewer ] = dict() - - self.state.change( "object_state" )( self.update_viewer ) - - with self: - vuetify.VCardTitle( "3D View" ) - view = plotter_ui( - self._pl, - add_menu_items=self.rendering_menu_extra_items, - style="position: absolute;", - ) - self.ctrl.view_update = view.update - - @property - def plotter( self ): - return self._pl - - @property - def source( self ): - return self._source - - def rendering_menu_extra_items( self ): - """ - Extend the default pyvista menu with custom button. - - For now, adding a button to show/hide all widgets. - """ - self.state.change( self.CUT_PLANE )( self._on_clip_visibility_change ) - vuetify.VDivider( vertical=True, classes="mr-1" ) - with vuetify.VTooltip( location="bottom" ): - with vuetify.Template( v_slot_activator=( "{ props }", ) ): - with html.Div( v_bind=( "props", ) ): - vuetify.VCheckbox( - v_model=( self.CUT_PLANE, True ), - icon=True, - true_icon="mdi-eye", - false_icon="mdi-eye-off", - dense=True, - hide_details=True, - ) - html.Span( "Show/Hide widgets" ) - - def update_viewer( self, object_state: list[ str, bool ], **kwargs ) -> None: - """ - Add from path the dataset given by the user. - Supported data type is: Vtkwell, Vtkmesh, InternalWell, Perforation. - - object_state : array used to store path to the data and if we want to show it or not. - """ - path = object_state[ 0 ] - show_obj = object_state[ 1 ] - - if path == "": - return - active_block = self.source.decode( path ) - - if isinstance( active_block, Vtkmesh ): - self._update_vtkmesh( active_block, show_obj ) - - if isinstance( active_block, Vtkwell ): - if self.region_engine.input.number_of_cells == 0 and show_obj: - - self.ctrl.on_add_warning( - "Can't display " + active_block.name, - "Please display the mesh before creating a well.", - ) - return - - self._update_vtkwell( active_block, path, show_obj ) - - if isinstance( active_block, InternalWell ): - if self.region_engine.input.number_of_cells == 0 and show_obj: - self.ctrl.on_add_warning( - "Can't display " + active_block.name, - "Please display the mesh before creating a well", - ) - return - - self._update_internalwell( active_block, path, show_obj ) - - if isinstance( active_block, Perforation ): - if self.well_engine.get_number_of_wells() == 0 and show_obj: - self.ctrl.on_add_warning( - "Can't display " + active_block.name, - "Please display a well before creating a perforation", - ) - return - self._update_perforation( active_block, show_obj, path ) - - def _on_clip_visibility_change( self, **kwargs ): - """Toggle cut plane visibility for all actors. - - Parameters - ---------- - **kwargs : dict, optional - Unused keyword arguments. - - """ - show_widgets = kwargs[ self.CUT_PLANE ] - if show_widgets: - self._setup_slider() - else: - self._remove_slider() - - if self.plotter.plane_widgets: - widgets = self.plotter.plane_widgets - widgets[ 0 ].SetEnabled( show_widgets ) - self.plotter.render() - - def _setup_slider( self ) -> None: - """ - Create slider to control in the gui well parameters. - """ - - wells_radius = self._get_tube_size() - self.plotter.add_slider_widget( - self._on_change_tube_size, - [ 1, 20 ], - title="Wells radius", - pointa=( 0.02, 0.12 ), - pointb=( 0.30, 0.12 ), - title_opacity=0.5, - title_color="black", - title_height=0.02, - value=wells_radius, - ) - - perforation_radius = self._get_perforation_size() - self.plotter.add_slider_widget( - self._on_change_perforation_size, - [ 1, 50 ], - title="Perforation radius", - title_opacity=0.5, - pointa=( 0.02, 0.25 ), - pointb=( 0.30, 0.25 ), - title_color="black", - title_height=0.02, - value=perforation_radius, - ) - - def _remove_slider( self ) -> None: - """ - Create slider to control in the gui well parameters. - """ - self.plotter.clear_slider_widgets() - - def _on_change_tube_size( self, value ) -> None: - self.well_engine.update( value ) - - def _get_tube_size( self ) -> float: - return self.well_engine.get_tube_size() - - def _on_change_perforation_size( self, value ) -> None: - for key, perforation in self._perforations.items(): - perforation.update_perforation_radius( value ) - - def _get_perforation_size( self ) -> float: - if len( self._perforations ) <= 0: - return 5 - - for key, perforation in self._perforations.items(): - return perforation.get_perforation_size() - - def _update_internalwell( self, well: InternalWell, path: str, show: bool ) -> None: - """ - Used to control the visibility of the InternalWell. - This method will create the mesh if it doesn't exist. - """ - if not show: - self.plotter.remove_actor( self.well_engine.get_actor( path ) ) - self.well_engine.remove( path ) - return - - points = self.__parse_polyline_property( well.polyline_node_coords, dtype=float ) - connectivity = self.__parse_polyline_property( well.polyline_segment_conn, dtype=int ) - connectivity = connectivity.flatten() - - sorted_points = [] - for id in connectivity: - sorted_points.append( points[ id ] ) - - well_polydata = pv.MultipleLines( sorted_points ) - index = self.well_engine.add_mesh( well_polydata, path ) - - tube_actor = self.plotter.add_mesh( self.well_engine.get_tube( index ) ) - self.well_engine.append_actor( path, tube_actor ) - - self.server.controller.view_update() - - def _update_vtkwell( self, well: Vtkwell, path: str, show: bool ) -> None: - """ - Used to control the visibility of the Vtkwell. - This method will create the mesh if it doesn't exist. - """ - if not show: - self.plotter.remove_actor( self.well_engine.get_actor( path ) ) - self.well_engine.remove( path ) - return - - well_polydata = pv.PolyData.SafeDownCast( pv.read( self.source.get_abs_path( well.file ) ) ) - index = self.well_engine.add_mesh( well_polydata, path ) - - tube_actor = self.plotter.add_mesh( self.well_engine.get_tube( index ) ) - self.well_engine.append_actor( path, tube_actor ) - - self.server.controller.view_update() - - def _update_vtkmesh( self, mesh: Vtkmesh, show: bool ) -> None: - """ - Used to control the visibility of the Vtkmesh. - This method will create the mesh if it doesn't exist. - - Additionally, a clip filter will be added. - """ - - if not show: - self.region_engine.reset() - self.plotter.clear_plane_widgets() - self.plotter.remove_actor( self._clip_mesh ) - return - - unsctructured_grid = pv.UnstructuredGrid.SafeDownCast( pv.read( self.source.get_abs_path( mesh.file ) ) ) - self.region_engine.add_mesh( unsctructured_grid ) - active_scalar = self.region_engine.input.active_scalars_name - self._clip_mesh = self.plotter.add_mesh_clip_plane( - self.region_engine.input, - origin=self.region_engine.input.center, - normal=[ -1, 0, 0 ], - crinkle=True, - show_edges=False, - cmap="glasbey_bw", - scalars=active_scalar, - ) - - self.server.controller.view_update() - - def _update_perforation( self, perforation: Perforation, show: bool, path: str ) -> None: - """ - Generate VTK dataset from a perforation. - """ - - if not show: - if path in self._perforations: - self._remove_perforation( path ) - return - - distance_from_head = float( perforation.distance_from_head ) - self._add_perforation( distance_from_head, path ) - - def _remove_perforation( self, path: str ) -> None: - """ - Remove all actor related to the given path and clean the stored perforation - """ - saved_perforation: PerforationViewer.PerforationViewer = self._perforations[ path ] - self.plotter.remove_actor( saved_perforation.extracted_cell ) - self.plotter.remove_actor( saved_perforation.perforation_actor ) - saved_perforation.reset() - - def _add_perforation( self, distance_from_head: float, path: str ) -> None: - """ - Generate perforation dataset based on the distance from the top of a polyline - """ - - polyline: pv.PolyData = self.well_engine.get_mesh( path ) - if polyline is None: - return - - point = polyline.points[ 0 ] - point_offsetted = [ - point[ 0 ], - point[ 1 ], - point[ 2 ] - distance_from_head, - ] - - center = [ point[ 0 ], point[ 1 ], point[ 2 ] - float( distance_from_head ) ] - sphere = pv.Sphere( radius=5, center=center ) - - perforation_actor = self.plotter.add_mesh( sphere ) - saved_perforation = PerforationViewer.PerforationViewer( sphere, center, 5, perforation_actor ) - - id = self.region_engine.input.find_closest_cell( point_offsetted ) - cell = self.region_engine.input.extract_cells( [ id ] ) - cell_actor = self.plotter.add_mesh( cell ) - saved_perforation.add_extracted_cell( cell_actor ) - - self._perforations[ path ] = saved_perforation - - def __parse_polyline_property( self, property: str, dtype: Type[ Any ] ) -> np.ndarray[ Any ]: - """ - Internal method used to parse and convert a property, such as polyline_node_coords, from an InternalWell. - This string always follow this for : - "{ { 800, 1450, 395.646 }, { 800, 1450, -554.354 } }" - """ - try: - nodes_str = property.split( "}, {" ) - points = [] - for i in range( 0, len( nodes_str ) ): - - nodes_str[ i ] = nodes_str[ i ].replace( " ", "" ) - nodes_str[ i ] = nodes_str[ i ].replace( "{", "" ) - nodes_str[ i ] = nodes_str[ i ].replace( "}", "" ) - - point = np.array( nodes_str[ i ].split( "," ), dtype=dtype ) - - points.append( point ) - - return np.array( points, dtype=dtype ) - except ValueError: - raise GeosTrameException( - "cannot be able to convert the property into a numeric array: ", - ValueError, - ) diff --git a/geos-trame/src/geos_trame/widgets/geos_trame.py b/geos-trame/src/geos_trame/widgets/geos_trame.py deleted file mode 100644 index 314d4b1fc..000000000 --- a/geos-trame/src/geos_trame/widgets/geos_trame.py +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. -# SPDX-FileContributor: Lionel Untereiner -from trame_client.widgets.core import AbstractElement - -from .. import module - -__all__ = [ - "Editor", -] - - -class HtmlElement( AbstractElement ): - - def __init__( self, _elem_name, children=None, **kwargs ): - super().__init__( _elem_name, children, **kwargs ) - if self.server: - self.server.enable_module( module ) diff --git a/geos-trame/tests/conftest.py b/geos-trame/tests/conftest.py index 90183df81..f38de5e35 100644 --- a/geos-trame/tests/conftest.py +++ b/geos-trame/tests/conftest.py @@ -5,6 +5,8 @@ from pathlib import Path from trame_client.utils.testing import FixtureHelper +# ruff: noqa + ROOT_PATH = Path( __file__ ).parent.parent.absolute() print( ROOT_PATH ) HELPER = FixtureHelper( ROOT_PATH ) @@ -18,7 +20,7 @@ def baseline_image(): @pytest.fixture -def server( xprocess, server_path ): +def server( xprocess, server_path: str ): name, Starter, Monitor = HELPER.get_xprocess_args( server_path ) Starter.timeout = 10 diff --git a/geos-trame/tests/test_file_handling.py b/geos-trame/tests/test_file_handling.py index 2cbf4c2cc..98a4bb6e5 100644 --- a/geos-trame/tests/test_file_handling.py +++ b/geos-trame/tests/test_file_handling.py @@ -1,17 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Kitware +from _pytest.capture import CaptureFixture from trame.app import get_server from trame_client.utils.testing import enable_testing -from geos_trame.app.core import GeosTrame +from geos.trame.app.core import GeosTrame -def test_unsupported_file( capsys ): - +def test_unsupported_file( capsys: CaptureFixture[ str ] ) -> None: + """Test unsupported file.""" server = enable_testing( get_server( client_type="vue3" ), "message" ) file_name = "tests/data/acous3D/acous3D_vtu.xml" GeosTrame( server, file_name ) captured = capsys.readouterr() - assert captured.err == "The file tests/data/acous3D/acous3D_vtu.xml cannot be parsed.\n" + assert ( captured.err == "The file tests/data/acous3D/acous3D_vtu.xml cannot be parsed.\n" ) diff --git a/geos-trame/tests/test_import.py b/geos-trame/tests/test_import.py index a3e3070eb..2f8e4a491 100644 --- a/geos-trame/tests/test_import.py +++ b/geos-trame/tests/test_import.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner -def test_import(): - from geos_trame.app.core import GeosTrame # noqa: F401 +def test_import() -> None: + """Test GeosTrame import.""" + from geos.trame.app.core import GeosTrame # noqa: F401 diff --git a/geos-trame/tests/test_load_and_visualize_synthetic_dataset.py b/geos-trame/tests/test_load_and_visualize_synthetic_dataset.py index 490f7bee2..ace0e20ff 100644 --- a/geos-trame/tests/test_load_and_visualize_synthetic_dataset.py +++ b/geos-trame/tests/test_load_and_visualize_synthetic_dataset.py @@ -5,12 +5,14 @@ import os from trame.app import get_server -from geos_trame.app.core import GeosTrame +from geos.trame.app.core import GeosTrame from seleniumbase import SB from selenium.webdriver.common.action_chains import ActionChains from selenium.webdriver.common.by import By +# ruff: noqa + @pytest.mark.skip( "Test to fix" ) @pytest.mark.parametrize( "server_path", [ "tests/utils/start_geos_trame_for_testing.py" ] ) diff --git a/geos-trame/tests/test_properties_checker.py b/geos-trame/tests/test_properties_checker.py new file mode 100644 index 000000000..b8bc2ac12 --- /dev/null +++ b/geos-trame/tests/test_properties_checker.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware +# ruff: noqa +from pathlib import Path + +from trame_server import Server +from trame_server.state import State +from trame_vuetify.ui.vuetify3 import VAppLayout + +from geos.trame.app.core import GeosTrame +from geos.trame.app.data_types.field_status import FieldStatus +from tests.trame_fixtures import trame_server_layout, trame_state + + +def test_properties_checker( trame_server_layout: tuple[ Server, VAppLayout ], trame_state: State ) -> None: + """Test properties checker.""" + root_path = Path( __file__ ).parent.absolute().__str__() + file_name = root_path + "/data/singlePhaseFlow/FieldCaseTutorial3_smoke.xml" + + geos_trame = GeosTrame( trame_server_layout[ 0 ], file_name ) + + field = trame_state.deck_tree[ 4 ][ "children" ][ 0 ] + assert field[ "valid" ] == FieldStatus.UNCHECKED.value + + geos_trame.simput_manager.proxymanager.get( "Problem/Mesh/0/VTKMesh/0" )[ "region_attribute" ] = "invalid" + geos_trame.properties_checker.check_fields() + assert field[ "valid" ] == FieldStatus.INVALID.value + + geos_trame.simput_manager.proxymanager.get( "Problem/Mesh/0/VTKMesh/0" )[ "region_attribute" ] = "attribute" + geos_trame.properties_checker.check_fields() + assert field[ "valid" ] == FieldStatus.VALID.value diff --git a/geos-trame/tests/test_saving_attribute_modification.py b/geos-trame/tests/test_saving_attribute_modification.py index 7995ab6dc..cd744c2d9 100644 --- a/geos-trame/tests/test_saving_attribute_modification.py +++ b/geos-trame/tests/test_saving_attribute_modification.py @@ -5,7 +5,7 @@ import pytest from trame.app import get_server -from geos_trame.app.core import GeosTrame +from geos.trame.app.core import GeosTrame from trame_client.utils.testing import enable_testing from seleniumbase import SB @@ -13,6 +13,7 @@ from selenium.webdriver.common.by import By +# ruff: noqa @pytest.mark.skip( "Test to fix" ) @pytest.mark.parametrize( "server_path", [ "tests/utils/start_geos_trame_for_testing.py" ] ) def test_saving_attribute_modification( server, capsys ): diff --git a/geos-trame/tests/test_saving_node_modification.py b/geos-trame/tests/test_saving_node_modification.py index 336de0884..b95325778 100644 --- a/geos-trame/tests/test_saving_node_modification.py +++ b/geos-trame/tests/test_saving_node_modification.py @@ -5,7 +5,7 @@ import pytest from trame.app import get_server -from geos_trame.app.core import GeosTrame +from geos.trame.app.core import GeosTrame from trame_client.utils.testing import enable_testing from seleniumbase import SB @@ -13,6 +13,7 @@ from selenium.webdriver.common.by import By +# ruff: noqa @pytest.mark.skip( "Test to fix" ) @pytest.mark.parametrize( "server_path", [ "tests/utils/start_geos_trame_for_testing.py" ] ) def test_saving_node_modification( server, capsys ): diff --git a/geos-trame/tests/test_saving_subnode_modifications.py b/geos-trame/tests/test_saving_subnode_modifications.py index 6876738ae..5db64cd29 100644 --- a/geos-trame/tests/test_saving_subnode_modifications.py +++ b/geos-trame/tests/test_saving_subnode_modifications.py @@ -5,7 +5,7 @@ import pytest from trame.app import get_server -from geos_trame.app.core import GeosTrame +from geos.trame.app.core import GeosTrame from trame_client.utils.testing import enable_testing from seleniumbase import SB @@ -13,6 +13,7 @@ from selenium.webdriver.common.by import By +# ruff: noqa @pytest.mark.skip( "Test to fix" ) @pytest.mark.parametrize( "server_path", [ "tests/utils/start_geos_trame_for_testing.py" ] ) def test_saving_subnode_modifications( server, capsys ): diff --git a/geos-trame/tests/test_well_intersection.py b/geos-trame/tests/test_well_intersection.py index f45841da8..f6e2e0fde 100644 --- a/geos-trame/tests/test_well_intersection.py +++ b/geos-trame/tests/test_well_intersection.py @@ -1,74 +1,78 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lucas Givord - Kitware -from trame.app import get_server -from trame_client.utils.testing import enable_testing -from geos_trame.app.core import GeosTrame +# ruff: noqa +from pathlib import Path +from trame_server import Server +from trame_server.state import State +from trame_vuetify.ui.vuetify3 import VAppLayout +from tests.trame_fixtures import trame_server_layout, trame_state -def test_internal_well_intersection(): +from geos.trame.app.core import GeosTrame - server = enable_testing( get_server( client_type="vue3" ), "message" ) - file_name = "geos-trame/tests/data/geosDeck/geosDeck.xml" - app = GeosTrame( server, file_name ) - app.state.ready() +def test_internal_well_intersection( trame_server_layout: tuple[ Server, VAppLayout ], trame_state: State ) -> None: + """Test internal well intersection.""" + root_path = Path( __file__ ).parent.absolute().__str__() + file_name = root_path + "/data/geosDeck/geosDeck.xml" - app.deckInspector.state.object_state = [ "Problem/Mesh/0/VTKMesh/0", True ] - app.deckInspector.state.flush() + app = GeosTrame( trame_server_layout[ 0 ], file_name ) - app.deckInspector.state.object_state = [ + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0", True ) + trame_state.flush() + + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0/InternalWell/0", True, - ] - app.deckInspector.state.flush() + ) + trame_state.flush() - app.deckInspector.state.object_state = [ + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0/InternalWell/0/Perforation/0", True, - ] - app.deckInspector.state.flush() + ) + trame_state.flush() - app.deckInspector.state.object_state = [ + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0/InternalWell/0/Perforation/1", True, - ] - app.deckInspector.state.flush() + ) + trame_state.flush() assert app.deckViewer.well_engine.get_number_of_wells() == 1 assert len( app.deckViewer._perforations ) == 2 -def test_vtk_well_intersection(): - - server = enable_testing( get_server( client_type="vue3" ), "message" ) - file_name = "geos-trame/tests/data/geosDeck/geosDeck.xml" +def test_vtk_well_intersection( trame_server_layout: tuple[ Server, VAppLayout ], trame_state: State ) -> None: + """Test vtk well intersection.""" + root_path = Path( __file__ ).parent.absolute().__str__() + file_name = root_path + "/data/geosDeck/geosDeck.xml" - app = GeosTrame( server, file_name ) - app.state.ready() + app = GeosTrame( trame_server_layout[ 0 ], file_name ) - app.deckInspector.state.object_state = [ "Problem/Mesh/0/VTKMesh/0", True ] - app.deckInspector.state.flush() + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0", True ) + trame_state.flush() - app.deckInspector.state.object_state = [ "Problem/Mesh/0/VTKMesh/0/VTKWell/0", True ] - app.deckInspector.state.flush() + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0/VTKWell/0", True ) + trame_state.flush() - app.deckInspector.state.object_state = [ + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0/VTKWell/0/Perforation/0", True, - ] - app.deckInspector.state.flush() + ) + trame_state.flush() - app.deckInspector.state.object_state = [ + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0/VTKWell/0/Perforation/1", True, - ] - app.deckInspector.state.flush() + ) + trame_state.flush() assert app.deckViewer.well_engine.get_number_of_wells() == 1 assert len( app.deckViewer._perforations ) == 2 - app.deckInspector.state.object_state = [ "Problem/Mesh/0/VTKMesh/0/VTKWell/0", False ] - app.deckInspector.state.flush() + trame_state.object_state = ( "Problem/Mesh/0/VTKMesh/0/VTKWell/0", False ) + trame_state.flush() assert app.deckViewer.well_engine.get_number_of_wells() == 0 diff --git a/geos-trame/tests/trame_fixtures.py b/geos-trame/tests/trame_fixtures.py new file mode 100644 index 000000000..0866dbd96 --- /dev/null +++ b/geos-trame/tests/trame_fixtures.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. +# SPDX-FileContributor: Kitware + +import pytest +from trame_server import Server +from trame_server.state import State +from trame_vuetify.ui.vuetify3 import VAppLayout +from typing import Generator + + +@pytest.fixture +def trame_server_layout() -> Generator[ tuple[ Server, VAppLayout ], None, None ]: + """Yield a test server and layout.""" + server = Server() + server.debug = True + + with VAppLayout( server ) as layout: + yield server, layout + + +@pytest.fixture +def trame_state( trame_server_layout: tuple[ Server, VAppLayout ] ) -> Generator[ State, None, None ]: + """Yield a test state.""" + trame_server_layout[ 0 ].state.ready() + yield trame_server_layout[ 0 ].state diff --git a/geos-trame/tests/utils/start_geos_trame_for_testing.py b/geos-trame/tests/utils/start_geos_trame_for_testing.py index 7de843861..73c6f8ba9 100644 --- a/geos-trame/tests/utils/start_geos_trame_for_testing.py +++ b/geos-trame/tests/utils/start_geos_trame_for_testing.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner from trame_client.utils.testing import enable_testing -from geos_trame.app.core import GeosTrame +from geos.trame.app.core import GeosTrame from trame.app import get_server from pathlib import Path diff --git a/geos-trame/tests/utils/testing_tools.py b/geos-trame/tests/utils/testing_tools.py index 7cef8984e..9d74e5c4f 100644 --- a/geos-trame/tests/utils/testing_tools.py +++ b/geos-trame/tests/utils/testing_tools.py @@ -1,24 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Lionel Untereiner -from PIL import Image -from PIL import ImageChops +from PIL import Image, ImageChops -def image_pixel_differences( base_image_path, compare_image_path ): - """ - Calculates the bounding box of the non-zero regions in the image. - :param base_image: target image to find - :param compare_image: set of images containing the target image +def image_pixel_differences( base_image_path: str, compare_image_path: str ) -> bool: + """Calculates the bounding box of the non-zero regions in the image. + + :param base_image_path: target image to find + :param compare_image_path: set of images containing the target image :return: True is the L1 value between each image is identitic, - False otherwise + False otherwise. """ - base_image = Image.open( base_image_path ) compare_image = Image.open( compare_image_path ) diff = ImageChops.difference( base_image, compare_image ) - if diff.getbbox(): - return False - else: - return True + return not diff.getbbox() diff --git a/geos-trame/vue-components/vite.config.js b/geos-trame/vue-components/vite.config.js index e66ebfe09..776e17b00 100644 --- a/geos-trame/vue-components/vite.config.js +++ b/geos-trame/vue-components/vite.config.js @@ -15,7 +15,7 @@ export default { }, }, }, - outDir: "../src/geos_trame/module/serve", + outDir: "../src/geos/trame/module/serve", assetsDir: ".", }, };