Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,11 +1552,11 @@ def __init__(
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")

self.blocks = blocks
self._blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}

# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
Expand Down Expand Up @@ -1603,7 +1603,9 @@ def __init__(
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
self.register_to_config(
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)

@property
def default_call_parameters(self) -> Dict[str, Any]:
Expand All @@ -1612,7 +1614,7 @@ def default_call_parameters(self) -> Dict[str, Any]:
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self.blocks.inputs:
for input_param in self._blocks.inputs:
params[input_param.name] = input_param.default
return params

Expand Down Expand Up @@ -1775,7 +1777,15 @@ def doc(self):
Returns:
- The docstring of the pipeline blocks
"""
return self.blocks.doc
return self._blocks.doc

@property
def blocks(self) -> ModularPipelineBlocks:
"""
Returns:
- A copy of the pipeline blocks
"""
return deepcopy(self._blocks)

def register_components(self, **kwargs):
"""
Expand Down Expand Up @@ -2509,7 +2519,7 @@ def _dict_to_component_spec(
)

def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)

Expand Down Expand Up @@ -2563,7 +2573,7 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =

# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
for expected_input_param in self.blocks.inputs:
for expected_input_param in self._blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
Expand All @@ -2582,9 +2592,9 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
# Run the pipeline
with torch.no_grad():
try:
_, state = self.blocks(self, state)
_, state = self._blocks(self, state)
except Exception:
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise

Expand Down
Loading