Base Sampler
Base sampler class for diffusion models.
This module provides a base abstract class for all samplers used in diffusion models. It defines the common interface that all samplers should implement.
BaseSampler
Bases: ABC
Abstract base class for all diffusion model samplers.
All samplers inherit from this class and must implement the `__call__` method, which performs the actual sampling process.
Attributes:
| Name | Type | Description |
|---|---|---|
| `diffusion` | `BaseDiffusion` | The diffusion model to sample from. |
| `verbose` | `bool` | Whether to print progress information during sampling. |
Source code in image_gen\samplers\base.py
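The contract described above can be sketched with a stand-in class. This is a minimal illustration, not the library's source: the `BaseSampler` below only mirrors the documented signatures, `IdentitySampler` is a hypothetical subclass, and plain Python lists stand in for tensors so the snippet runs without any dependencies.

```python
from abc import ABC, abstractmethod


class BaseSampler(ABC):
    """Stand-in mirroring the documented interface (not the library's code)."""

    def __init__(self, diffusion, *_, verbose=True, **__):
        self.diffusion = diffusion
        self.verbose = verbose

    @abstractmethod
    def __call__(self, x_T, score_model, *args, n_steps=500, seed=None,
                 callback=None, callback_frequency=50, guidance=None,
                 **kwargs):
        ...


class IdentitySampler(BaseSampler):
    """Hypothetical subclass: returns the initial noise unchanged."""

    def __call__(self, x_T, score_model, *args, n_steps=500, seed=None,
                 callback=None, callback_frequency=50, guidance=None,
                 **kwargs):
        return x_T


# The abstract base cannot be instantiated directly.
try:
    BaseSampler(diffusion=None)
    instantiated = True
except TypeError:
    instantiated = False

# A subclass that implements __call__ can be created and called.
sampler = IdentitySampler(diffusion=None, verbose=False)
result = sampler([0.0, 1.0], score_model=lambda x, t: x)
```

Because `__call__` is marked abstract, Python rejects any subclass that forgets to implement it at instantiation time rather than at the first sampling call.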
__call__(x_T, score_model, *args, n_steps=500, seed=None, callback=None, callback_frequency=50, guidance=None, **kwargs)
abstractmethod
Perform the sampling process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x_T` | `Tensor` | The initial noise tensor to start sampling from. | required |
| `score_model` | `Callable` | The score model function that predicts the score. | required |
| `*args` | `Any` | Additional positional arguments. | `()` |
| `n_steps` | `int` | Number of sampling steps. Defaults to 500. | `500` |
| `seed` | `Optional[int]` | Random seed for reproducibility. Defaults to None. | `None` |
| `callback` | `Optional[Callable[[Tensor, int], None]]` | Optional function called during sampling to monitor progress. It takes the current sample and step number as inputs. Defaults to None. | `None` |
| `callback_frequency` | `int` | How often to call the callback function. Defaults to 50. | `50` |
| `guidance` | `Optional[Callable[[Tensor, Tensor, Tensor], Tensor]]` | Optional guidance function for conditional sampling. Defaults to None. | `None` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | A tuple containing the final sample and the sequence of all samples. |
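How `n_steps`, `seed`, `callback`, and `callback_frequency` might interact inside an implementation can be sketched as follows. Everything here is an assumption made for illustration: `toy_sampling_loop` and `toy_score` are hypothetical names, the update rule is a placeholder rather than a real sampler, lists of floats stand in for tensors, and firing the callback whenever `step % callback_frequency == 0` is only one plausible reading of "how often". The `guidance` argument is left out because the meaning of its three tensor inputs is not documented here.

```python
import random


def toy_sampling_loop(x_T, score_model, n_steps=500, seed=None,
                      callback=None, callback_frequency=50):
    """Hypothetical loop; the update rule is a placeholder."""
    rng = random.Random(seed)  # local RNG, so a fixed seed reproduces a run
    x = list(x_T)
    step_size = 1.0 / n_steps
    for step in range(n_steps):
        t = 1.0 - step * step_size  # time runs from 1 towards 0
        score = score_model(x, t)
        # Placeholder update: drift along the score plus a little noise.
        x = [xi + step_size * si
             + 0.01 * (step_size ** 0.5) * rng.gauss(0.0, 1.0)
             for xi, si in zip(x, score)]
        # Assumed convention: report every `callback_frequency` steps.
        if callback is not None and step % callback_frequency == 0:
            callback(x, step)
    return x


def toy_score(x, t):
    # Score of a standard normal: pulls each coordinate towards zero.
    return [-xi for xi in x]


seen_steps = []
first = toy_sampling_loop([2.0, -1.0], toy_score, n_steps=200, seed=0,
                          callback=lambda x, step: seen_steps.append(step),
                          callback_frequency=50)
second = toy_sampling_loop([2.0, -1.0], toy_score, n_steps=200, seed=0)
```

With the same seed the two runs produce identical samples, and the callback only observes the sample without affecting it.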
__init__(diffusion, *_, verbose=True, **__)
Initialize the sampler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `diffusion` | `BaseDiffusion` | The diffusion model to sample from. | required |
| `verbose` | `bool` | Whether to print progress information during sampling. Defaults to True. | `True` |
__str__()
Return a string representation of the sampler.
Returns:
| Type | Description |
|---|---|
| `str` | A string with the sampler's class name and its configuration. |
config()
Return the configuration of the sampler.
Returns:
| Type | Description |
|---|---|
| `dict` | A dictionary with the sampler's configuration parameters. |
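One way `config()` and `__str__()` could fit together is sketched below. This is a guess at the relationship, not the library's implementation: `ToySampler` and its `step_scale` parameter are hypothetical, and both the contents of the dictionary and the `ClassName(key=value)` string format are assumptions.

```python
class ToySampler:
    """Hypothetical sampler illustrating config() and __str__()."""

    def __init__(self, diffusion, *_, verbose=True, step_scale=1.0, **__):
        self.diffusion = diffusion
        self.verbose = verbose
        self.step_scale = step_scale  # hypothetical extra parameter

    def config(self):
        # Assumed: only sampler-specific settings appear in the dictionary.
        return {"step_scale": self.step_scale}

    def __str__(self):
        # Assumed format: ClassName(key=value, ...), built from config().
        params = ", ".join(f"{k}={v!r}" for k, v in self.config().items())
        return f"{type(self).__name__}({params})"


sampler = ToySampler(diffusion=None, step_scale=0.5)
text = str(sampler)
```

Building the string from `config()` keeps the two methods in agreement, so a subclass that adds parameters only has to extend the dictionary.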