diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index a86ca1dbf6..f316adc160 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -61,7 +61,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata)
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy resource_allocation_strategy)
     : TRTEngine(
           "deserialized_trt",
           serialized_engine,
@@ -71,7 +72,8 @@ TRTEngine::TRTEngine(
           target_platform,
           hardware_compatible,
           requires_output_allocator,
-          serialized_metadata) {}
+          serialized_metadata,
+          resource_allocation_strategy) {}
 
 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
@@ -83,7 +85,10 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           Platform(serialized_info[TARGET_PLATFORM_IDX]),
           static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
           static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
-          serialized_info[SERIALIZED_METADATA_IDX]) {}
+          serialized_info[SERIALIZED_METADATA_IDX],
+          (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
+               ? ResourceAllocationStrategy::kDynamic
+               : ResourceAllocationStrategy::kStatic)) {}
 
 TRTEngine::TRTEngine(
     const std::string& mod_name,
@@ -94,7 +99,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata) {
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy resource_allocation_strategy) {
   TORCHTRT_CHECK(
       is_supported_on_current_platform(target_platform),
       "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -124,7 +130,16 @@ TRTEngine::TRTEngine(
     cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
   }
 
-  exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  this->resource_allocation_strategy = resource_allocation_strategy;
+  LOG_DEBUG(
+      "Resource allocation strategy: "
+      << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));
+  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
+    this->exec_ctx =
+        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
+  } else {
+    this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  }
   TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
 
   runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
@@ -401,6 +416,7 @@ std::string TRTEngine::to_str() const {
   ss << "  Device: " << device_info << std::endl;
   ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   ss << "  Target Platform: " << target_platform << std::endl;
+  ss << "  Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
   // clang-format on
   return ss.str();
 }
@@ -444,7 +460,8 @@ FlattenedState TRTEngine::__obj_flatten__() {
       std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
       std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
       std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
-      std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
+      std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
+      std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]));
 }
 
 std::vector<std::string> TRTEngine::serialize() {
@@ -467,6 +484,8 @@ std::vector<std::string> TRTEngine::serialize() {
   serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
   serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
   serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
+  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
+      this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
 
   return serialized_info;
 }
@@ -475,6 +494,20 @@ void TRTEngine::reset_captured_graph() {
   cudagraph.reset();
 }
 
+void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) {
+  if (new_strategy != this->resource_allocation_strategy) {
+    this->resource_allocation_strategy = new_strategy;
+    if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
+      LOG_DEBUG("Setting resource allocation strategy to dynamic");
+      this->exec_ctx =
+          make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
+    } else {
+      LOG_DEBUG("Setting resource allocation strategy to static");
+      this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
+    }
+  }
+}
+
 } // namespace runtime
 } // namespace core
 } // namespace torch_tensorrt
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
index 7e0de52126..5a69fe9754 100644
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -29,7 +29,8 @@ using FlattenedState = std::tuple<
    std::tuple<std::string, std::string>, // HW compatibility
    std::tuple<std::string, std::string>, // requires_output_allocator
    std::tuple<std::string, std::string>, // serialized metadata
-   std::tuple<std::string, std::string>>; // Platform
+   std::tuple<std::string, std::string>, // Platform
+   std::tuple<std::string, std::string>>; // Resource Allocation Strategy
 
 struct TorchTRTRuntimeStates {
   // Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -98,6 +99,8 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
 };
 
 struct TRTEngine : torch::CustomClassHolder {
+  // Resource Allocation Strategy
+  typedef enum { kStatic = 0, kDynamic } ResourceAllocationStrategy;
   // Each engine needs it's own runtime object
   std::shared_ptr<nvinfer1::IRuntime> rt;
   std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
@@ -129,7 +132,9 @@ struct TRTEngine : torch::CustomClassHolder {
       const Platform& target_platform = get_current_platform(),
       bool hardware_compatible = false,
       bool requires_output_allocator = false,
-      const std::string& serialized_metadata = "");
+      const std::string& serialized_metadata = "",
+      const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
+          TRTEngine::ResourceAllocationStrategy::kStatic);
 
   TRTEngine(std::vector<std::string> serialized_info);
 
@@ -142,7 +147,9 @@ struct TRTEngine : torch::CustomClassHolder {
       const Platform& target_platform = get_current_platform(),
       bool hardware_compatible = false,
       bool requires_output_allocator = false,
-      const std::string& serialized_metadata = "");
+      const std::string& serialized_metadata = "",
+      const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
+          TRTEngine::ResourceAllocationStrategy::kStatic);
 
   TRTEngine& operator=(const TRTEngine& other);
   std::string to_str() const;
@@ -203,6 +210,9 @@ struct TRTEngine : torch::CustomClassHolder {
   std::string cuda_graph_debug_path;
   std::mutex mu;
   std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
+  ResourceAllocationStrategy resource_allocation_strategy = kStatic;
+  void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy);
+  ResourceAllocationStrategy get_resource_allocation_strategy();
 };
 
 } // namespace runtime
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 54e9701c9e..8338fde257 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -201,6 +201,15 @@ void create_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
 }
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  torch::Tensor dynamic_workspace;
+  if (compiled_engine->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
+    // getDeviceMemorySizeV2() reports the workspace size in bytes, so allocate a uint8 buffer of that length.
+    dynamic_workspace = torch::empty(
+        {compiled_engine->cuda_engine->getDeviceMemorySizeV2()},
+        torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA));
+    compiled_engine->exec_ctx->setDeviceMemory(dynamic_workspace.data_ptr());
+  }
+
   auto run_standard_execution = [&]() {
     bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
     bool shape_changed = _validate_shapes(inputs, compiled_engine);
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index 49cd12f86a..9baa0df32c 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -92,6 +92,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("reset_captured_graph", &TRTEngine::reset_captured_graph)
         .def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned)
         .def("are_output_tensors_unowned", &TRTEngine::are_output_tensors_unowned)
+        .def(
+            "use_dynamically_allocated_resources",
+            [](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
+              self->set_resource_allocation_strategy(
+                  dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic
+                          : TRTEngine::ResourceAllocationStrategy::kStatic);
+            })
         .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
         .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
         .def_property(
@@ -104,6 +111,10 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
         [](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
           serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
+          LOG_DEBUG(
+              "Deserialized resource allocation strategy: "
+              << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic"
+                                                                                                  : "Static"));
           TRTEngine::verify_serialization_fmt(serialized_info);
           return c10::make_intrusive<TRTEngine>(serialized_info);
         });
@@ -137,6 +148,7 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; });
   m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; });
   m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
+  m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; });
   m.def("_platform_linux_x86_64", []() -> std::string {
     auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
     return it->second;
diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h
index 894df55bfe..d8f71683d3 100644
--- a/core/runtime/runtime.h
+++ b/core/runtime/runtime.h
@@ -16,7 +16,7 @@ namespace core {
 namespace runtime {
 
 using EngineID = int64_t;
-const std::string ABI_VERSION = "7";
+const std::string ABI_VERSION = "8";
 extern bool MULTI_DEVICE_SAFE_MODE;
 
 typedef enum {
@@ -38,6 +38,7 @@ typedef enum {
   SERIALIZED_METADATA_IDX,
   TARGET_PLATFORM_IDX,
   REQUIRES_OUTPUT_ALLOCATOR_IDX,
+  RESOURCE_ALLOCATION_STRATEGY_IDX,
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;
 
diff --git a/examples/dynamo/dynamic_memory_allocation.py b/examples/dynamo/dynamic_memory_allocation.py
new file mode 100644
index 0000000000..fe64e0e3b7
--- /dev/null
+++ b/examples/dynamo/dynamic_memory_allocation.py
@@ -0,0 +1,45 @@
+# %%
+import gc
+import time
+
+import numpy as np
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+
+np.random.seed(5)
+torch.manual_seed(5)
+inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+
+settings = {
+    "ir": "dynamo",
+    "use_python_runtime": False,
+    "enabled_precisions": {torch.float32},
+    "immutable_weights": False,
+    "lazy_engine_init": True,
+    "dynamically_allocate_resources": True,
+}
+
+model = models.resnet152(pretrained=True).eval().to("cuda")
+compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
+print("Memory used (GB):", (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)
+compiled_module(*inputs)
+
+
+time.sleep(30)
+with torch_trt.dynamo.runtime.ResourceAllocationStrategy(
+    compiled_module, dynamically_allocate_resources=False
+):
+    print(
+        "Memory used (GB):",
+        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
+    )
+    compiled_module(*inputs)
+    gc.collect()
+    torch.cuda.empty_cache()
+    time.sleep(30)
+    print(
+        "Memory used (GB):",
+        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
+    )
+    compiled_module(*inputs)
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index cbde956a88..82532fd1f3 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -110,6 +110,7 @@ def cross_compile_for_windows(
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
     cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
+    dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -186,6 +187,7 @@ def cross_compile_for_windows(
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
         enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
         cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
+        dynamically_allocate_resources (bool): Dynamically allocate the TensorRT engine workspace at execution time rather than holding it for the lifetime of the engine. Default is False.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -343,6 +345,7 @@ def cross_compile_for_windows(
         "use_distributed_mode_trace": use_distributed_mode_trace,
         "enable_resource_partitioning": enable_resource_partitioning,
         "cpu_memory_budget": cpu_memory_budget,
+        "dynamically_allocate_resources": dynamically_allocate_resources,
     }
 
     # disable the following settings is not supported for cross compilation for windows feature
@@ -459,6 +462,7 @@ def compile(
     ] = _defaults.AUTOCAST_CALIBRATION_DATALOADER,
     cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
     enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
+    dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -545,6 +549,7 @@ def compile(
         autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
         enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
         cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
+        dynamically_allocate_resources (bool): Dynamically allocate the TensorRT engine workspace at execution time rather than holding it for the lifetime of the engine. Default is False.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -747,6 +752,7 @@ def compile(
         "autocast_calibration_dataloader": autocast_calibration_dataloader,
         "enable_resource_partitioning": enable_resource_partitioning,
         "cpu_memory_budget": cpu_memory_budget,
+        "dynamically_allocate_resources": dynamically_allocate_resources,
     }
     logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index fd093d402f..dbf94243a2 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -66,6 +66,7 @@
 AUTOCAST_CALIBRATION_DATALOADER = None
 ENABLE_RESOURCE_PARTITIONING = False
 CPU_MEMORY_BUDGET = None
+DYNAMICALLY_ALLOCATE_RESOURCES = False
 
 if platform.system() == "Linux":
     import pwd
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 7c20f44607..82fde7b41a 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -20,6 +20,7 @@
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
     DLA_SRAM_SIZE,
+    DYNAMICALLY_ALLOCATE_RESOURCES,
     DRYRUN,
     ENABLE_AUTOCAST,
     ENABLE_CROSS_COMPILE_FOR_WINDOWS,
@@ -115,6 +116,8 @@ class CompilationSettings:
         autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
         autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None.
         autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
+        offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation.
+        dynamically_allocate_resources (bool): Dynamically allocate the TensorRT engine workspace at execution time. Default is False.
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -173,6 +176,7 @@ class CompilationSettings:
     )
     enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING
     cpu_memory_budget: int = CPU_MEMORY_BUDGET
+    dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES
 
     def __getstate__(self) -> dict[str, Any]:
         from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
diff --git a/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py b/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py
new file mode 100644
index 0000000000..3e570f4d78
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+import torch
+
+
+class ResourceAllocationStrategy(torch.nn.Module):  # type: ignore[misc]
+    """
+    ResourceAllocationStrategy is a context manager that temporarily switches the resource
+    allocation mode of every TRT submodule of the given compiled_module. On entry it applies
+    the requested mode (dynamic by default); on exit it restores the opposite, original mode.
+    """
+
+    def __init__(
+        self,
+        compiled_module: torch.nn.Module,
+        dynamically_allocate_resources: bool = True,
+    ) -> None:
+        super(ResourceAllocationStrategy, self).__init__()
+        self.compiled_module = compiled_module
+        self.dynamically_allocate_resources = dynamically_allocate_resources
+
+    def __enter__(self) -> None:
+        for name, submodule in self.compiled_module.named_modules():
+            if "_run_on_acc" in name:
+                submodule.use_dynamically_allocated_resources(
+                    dynamically_allocate_resources=self.dynamically_allocate_resources
+                )
+
+    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
+        # Restore the original allocation mode by applying the opposite setting.
+        for name, submodule in self.compiled_module.named_modules():
+            if "_run_on_acc" in name:
+                submodule.use_dynamically_allocated_resources(
+                    dynamically_allocate_resources=not self.dynamically_allocate_resources
+                )
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 23c372167d..7ed48fdb7f 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -50,7 +50,10 @@
     REQUIRES_OUTPUT_ALLOCATOR_IDX = (
         torch.ops.tensorrt.REQUIRES_OUTPUT_ALLOCATOR_IDX()
     )  # 9
-    SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 10
+    RESOURCE_ALLOCATION_STRATEGY_IDX = (
+        torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX()
+    )  # 10
+    SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 11
 
 
 @for_all_methods(needs_torch_tensorrt_runtime)
@@ -139,6 +142,7 @@ def __init__(
         self.serialized_engine = serialized_engine
         self.engine = None
         self.requires_output_allocator = requires_output_allocator
+        self.dynamically_allocate_resources = settings.dynamically_allocate_resources
 
         if (
             serialized_engine
@@ -189,6 +193,12 @@ def _pack_engine_info(self) -> List[str | bytes]:
         engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str(
             int(self.requires_output_allocator)
         )
+        logger.debug(
+            f"Provided resource allocation strategy (dynamic: {self.dynamically_allocate_resources})"
+        )
+        engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str(
+            int(self.dynamically_allocate_resources)
+        )
 
         return engine_info
 
@@ -217,6 +227,14 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
     def _reset_captured_graph(self) -> None:
         self.engine.reset_captured_graph()
 
+    def use_dynamically_allocated_resources(
+        self, dynamically_allocate_resources: bool = False
+    ) -> None:
+        self.dynamically_allocate_resources = dynamically_allocate_resources
+        self.engine.use_dynamically_allocated_resources(
+            self.dynamically_allocate_resources
+        )
+
     def setup_engine(self) -> None:
         """
         Setup engine for a module which has deferred engine setup.
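A quick sanity check, offered as a sketch only, that the new serialization slot is exposed through the registered custom ops. The expected values follow the `SerializedInfoIndex` order in core/runtime/runtime.h after this change and the `# 10` / `# 11` comments above; the check assumes the C++ runtime library is available so the `tensorrt` ops are registered.

import torch
import torch_tensorrt  # noqa: F401  (loads the runtime library that registers the ops)

# RESOURCE_ALLOCATION_STRATEGY_IDX was appended after REQUIRES_OUTPUT_ALLOCATOR_IDX (9),
# pushing SERIALIZATION_LEN from 10 to 11 (and ABI_VERSION from "7" to "8").
assert torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() == 10
assert torch.ops.tensorrt.SERIALIZATION_LEN() == 11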
diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py
index de47d942e9..0eb66b24b0 100644
--- a/py/torch_tensorrt/dynamo/runtime/__init__.py
+++ b/py/torch_tensorrt/dynamo/runtime/__init__.py
@@ -2,6 +2,9 @@
 from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (  # noqa: F401
     PythonTorchTensorRTModule,
 )
+from torch_tensorrt.dynamo.runtime._ResourceAllocator import (  # noqa: F401
+    ResourceAllocationStrategy,
+)
 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (  # noqa: F401
     TorchTensorRTModule,
 )
diff --git a/tests/py/dynamo/runtime/test_005_dynamic_allocation.py b/tests/py/dynamo/runtime/test_005_dynamic_allocation.py
new file mode 100644
index 0000000000..efdc13c284
--- /dev/null
+++ b/tests/py/dynamo/runtime/test_005_dynamic_allocation.py
@@ -0,0 +1,51 @@
+import unittest
+
+import torch
+import torch.nn.functional as F
+import torch_tensorrt as torch_trt
+from torch import nn
+from torch.testing._internal.common_utils import TestCase
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+
+assertions = unittest.TestCase()
+
+
+class TestDynamicAllocation(TestCase):
+    def test_dynamic_allocation(self):
+
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(3, 6, 3, 1)
+                self.conv2 = nn.Conv2d(6, 16, 3, 1)
+
+            def forward(self, x):
+                x = F.relu(self.conv1(x))
+                x = F.relu(self.conv2(x))
+                return x
+
+        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+
+        settings = {
+            "ir": "dynamo",
+            "use_python_runtime": False,
+            "enabled_precisions": {torch.float32},
+            "immutable_weights": False,
+            "lazy_engine_init": True,
+            "dynamically_allocate_resources": True,
+        }
+
+        model = Net().eval().to("cuda")
+        compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
+        compiled_module(*inputs)
+
+        # Inference on PyTorch model
+        model_output = model(*inputs)
+        cos_sim = cosine_similarity(model_output, compiled_module(*inputs))
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+        # Clean up model env
+        torch._dynamo.reset()
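For reviewers, a condensed usage sketch of the surface this PR adds, distilled from examples/dynamo/dynamic_memory_allocation.py. ResNet-18 stands in here for the example's ResNet-152; the flag and context-manager names are exactly those introduced above, and behavior follows the hunks as written.

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

inputs = [torch.rand((8, 3, 224, 224)).to("cuda")]
model = models.resnet18().eval().to("cuda")

# Engines compiled with dynamically_allocate_resources=True create their execution
# context with nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED, so the
# activation workspace is allocated per call in execute_engine and released between
# calls, instead of being held for the lifetime of the engine.
trt_module = torch_trt.compile(
    model, inputs=inputs, ir="dynamo", dynamically_allocate_resources=True
)
trt_module(*inputs)

# The strategy can also be flipped at runtime for all TRT submodules; on exit the
# context manager restores the opposite (original) mode.
with torch_trt.dynamo.runtime.ResourceAllocationStrategy(
    trt_module, dynamically_allocate_resources=False
):
    trt_module(*inputs)

The trade-off is throughput versus idle footprint: a statically allocated context keeps its workspace resident for maximum performance, while the dynamic strategy pays a per-call allocation (served by PyTorch's caching allocator) to free device memory between invocations.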