Skip to content

trestle.core.generators

trestle.core.generators ¤

Capabilities to allow the generation of various oscal objects.

Attributes¤

TG = TypeVar('TG', bound=OscalBaseModel) module-attribute ¤

logger = logging.getLogger(__name__) module-attribute ¤

sample_base64 = Base64(filename=(const.REPLACE_ME), **{'media-type': const.REPLACE_ME}, value=sample_base64_value) module-attribute ¤

sample_base64_value = 0 module-attribute ¤

sample_date_value = '2400-02-29' module-attribute ¤

sample_method = Methods.EXAMINE module-attribute ¤

sample_observation_type_valid_value = ObservationTypeValidValues.historic module-attribute ¤

sample_task_valid_value = TaskValidValues.milestone module-attribute ¤

type_base64 = type(sample_base64) module-attribute ¤

Classes¤

Functions¤

generate_sample_model(model, include_optional=False, depth=-1) ¤

Given a model class, generate an object of that class with sample values.

Optional fields are generated only when the include_optional flag is enabled. Any array objects will have a single entry injected into them.

Note: Trestle generate will not activate recursive loops irrespective of the depth flag.

Parameters:

Name Type Description Default
model Union[Type[TG], List[TG], Dict[str, TG]]

The model type provided. Typically for a user as an OscalBaseModel Subclass.

required
include_optional bool

Whether or not to generate optional fields.

False
depth int

Depth of the tree down to which optional fields are generated. Negative values (the default) remove the limit.

-1

Returns:

Type Description
TG

The generated instance, with pro-forma values filled out as far as possible.

Source code in trestle/core/generators.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def _is_union_origin(origin: Any) -> bool:
    """Return True if origin is typing.Union or the PEP 604 ``types.UnionType`` (``X | Y``)."""
    return origin == Union or str(origin) == "<class 'types.UnionType'>"


def generate_sample_model(
    model: Union[Type[TG], List[TG], Dict[str, TG]], include_optional: bool = False, depth: int = -1
) -> TG:
    """Given a model class, generate an object of that class with sample values.

    Optional fields can be generated by enabling a flag. Any array objects will have a single entry injected into them.

    Note: Trestle generate will not activate recursive loops irrespective of the depth flag.

    Args:
        model: The model type provided. Typically for a user as an OscalBaseModel Subclass.
        include_optional: Whether or not to generate optional fields.
        depth: Depth of the tree down to which optional fields are generated. Negative values (default) remove
            the limit. Required fields are always generated, also below this depth.

    Returns:
        The generated instance with pro-forma values filled out as best as possible. When model is a list or dict
        type, a single-entry list (or dict keyed by const.REPLACE_ME) of such values is returned instead.

    Raises:
        TrestleError: if model is not an OscalBaseModel and its collection type is neither list nor dict.
    """
    effective_optional = include_optional and depth != 0
    # Depth handed to recursive calls. A negative depth stays negative (unlimited). Depth 0 stays 0 so that
    # required children generated below the limit do not see -1 ("unlimited") and re-enable optional fields.
    child_depth = depth - 1 if depth != 0 else 0

    model_type = model
    # Normalize a collection type (e.g. List[X]) down to its container origin and its inner item type.
    if utils.is_collection_field_type(model):
        model_type = utils.get_origin(model)
        model = utils.get_inner_type(model)

    # Handle Union types at the top level (e.g., when model is Union[Parameter1, Parameter2]).
    # This can happen when get_inner_type returns a Union from a list[Union[...]].
    if _is_union_origin(utils.get_origin(model)):
        union_args = typing.get_args(model)
        # Prefer the first non-None OscalBaseModel type in the union
        for arg in union_args:
            if arg is not type(None) and safe_is_sub(arg, OscalBaseModel):
                model = arg
                break
        else:
            # If no OscalBaseModel found, use first non-None type
            model = next((arg for arg in union_args if arg is not type(None)), union_args[0])

    model = cast(TG, model)

    model_dict = {}
    # this block is needed to avoid situations where an inbuilt is inside a list / dict.
    # the only time dict ever appears is with include_all, which is handled specially
    # the only type of collection possible after OSCAL 1.0.0 is list
    if safe_is_sub(model, OscalBaseModel):
        for field in model.__fields__:
            if model_type in [OscalVersion]:
                model_dict[field] = OSCAL_VERSION
                break
            field_info = model.__fields__[field]  # type: ignore
            required = field_info.required

            # include_all is emitted as an empty dict when required, or when optional fields were requested.
            if field == 'include_all':
                if required or include_optional:
                    model_dict[field] = {}
                continue
            outer_type = field_info.outer_type_

            # Skip fields with unresolved ForwardRefs; if the field is required, assume it is a list
            # type and provide an empty list.
            if isinstance(outer_type, (str, ForwardRef)):
                if required:
                    model_dict[field] = []
                continue

            # Handle both typing.Union and types.UnionType (Python 3.10+ uses | operator)
            if _is_union_origin(utils.get_origin(outer_type)):
                # For Union types, prefer Enum types over other types for sample generation.
                # This handles fields like Union[ConstrainedStr, Enum, None].
                union_args = typing.get_args(outer_type)
                enum_type = next(
                    (arg for arg in union_args if arg is not type(None) and safe_is_sub(arg, Enum)), None
                )
                if enum_type:
                    outer_type = enum_type
                else:
                    # Fall back to the first non-None type that is not an unresolved ForwardRef.
                    outer_type = next(
                        (
                            arg
                            for arg in union_args
                            if arg is not type(None) and not isinstance(arg, (str, ForwardRef))
                        ),
                        None,
                    )
                    if outer_type is None:
                        # All members are ForwardRefs or None: skip this field
                        continue

            if not (required or effective_optional):
                continue

            # FIXME could be ForwardRef('SystemComponentStatus')
            if utils.is_collection_field_type(outer_type):
                inner_type = utils.get_inner_type(outer_type)
                # Skip circular references: the item type is the model itself ...
                if inner_type == model:
                    continue
                # ... or the item type is a Union that has the model as one of its variants.
                if _is_union_origin(utils.get_origin(inner_type)) and model in typing.get_args(inner_type):
                    continue
                # Always generate one entry (OSCAL lists are assumed to require at least one item);
                # child_depth keeps the optional-field limit in force below depth 0.
                model_dict[field] = generate_sample_model(
                    outer_type, include_optional=include_optional, depth=child_depth
                )
            elif is_by_type(outer_type):
                model_dict[field] = generate_sample_value_by_type(outer_type, field)
            elif safe_is_sub(outer_type, OscalBaseModel):
                # At depth 0 only required fields get here, so nested models are always generated;
                # child_depth stops the depth limit from leaking into the subtree.
                model_dict[field] = generate_sample_model(
                    outer_type, include_optional=include_optional, depth=child_depth
                )
            else:
                # Handle special cases (hacking)
                if model_type in [Base64Datatype]:
                    model_dict[field] = sample_base64_value
                elif model_type in [Base64]:
                    if field == 'filename':
                        model_dict[field] = sample_base64.filename
                    elif field == 'media_type':
                        model_dict[field] = sample_base64.media_type
                    elif field == 'value':
                        model_dict[field] = sample_base64.value
                elif model_type in [DateDatatype]:
                    model_dict[field] = sample_date_value
                # Hacking here:
                # Root models should ideally not exist, however, sometimes we are stuck with them.
                # If that is the case we need sufficient information on the type in order to generate a model.
                # E.g. we need the type of the container.
                elif field == '__root__' and hasattr(model, '__name__'):
                    model_dict[field] = generate_sample_value_by_type(
                        outer_type, str_utils.classname_to_alias(model.__name__, AliasMode.FIELD)
                    )
                else:
                    model_dict[field] = generate_sample_value_by_type(outer_type, field)
        # Note: this assumes list constrains in oscal are always 1 as a minimum size. if two this may still fail.
    else:
        # Builtin (non-OSCAL) item type inside a collection.
        if model_type is list:
            return [generate_sample_value_by_type(model, '')]  # type: ignore
        if model_type is dict:
            return {const.REPLACE_ME: generate_sample_value_by_type(model, '')}  # type: ignore
        raise err.TrestleError('Unhandled collection type.')
    if model_type is list:
        return [model(**model_dict)]  # type: ignore
    if model_type is dict:
        return {const.REPLACE_ME: model(**model_dict)}  # type: ignore
    return model(**model_dict)  # type: ignore

generate_sample_value_by_type(type_, field_name) ¤

Given a type, return sample value.

The field name is used to choose more specific sample values (for example UUIDs, dates or the OSCAL version) for string-like types.

Source code in trestle/core/generators.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def generate_sample_value_by_type(type_: type, field_name: str) -> Union[datetime, bool, int, str, float, Enum, Base64]:
    """Given a type, return a sample value for it.

    Args:
        type_: The (field) type to generate a sample value for.
        field_name: Name of the field being generated. For string-like types it is used to pick more specific
            samples (UUIDs, dates, the OSCAL version). It may be '' for items inside lists.

    Returns:
        A sample value of (or compatible with) type_; const.REPLACE_ME when nothing more specific applies.

    Raises:
        TrestleError: if type_ is the bare list type.
    """
    # FIXME: Should be in separate generator module as it inherits EVERYTHING
    if is_enum_method(type_):
        return sample_method
    if is_enum_task_valid_value(type_):
        return sample_task_valid_value
    if is_enum_observation_type_valid_value(type_):
        return sample_observation_type_valid_value
    if type_ is Base64:
        return sample_base64
    if type_ is datetime:
        return datetime.now().astimezone()
    if type_ is bool:
        return False
    if type_ is int:
        return 0
    if type_ is float:
        return 0.00
    if safe_is_sub(type_, ConstrainedStr) or (hasattr(type_, '__name__') and 'ConstrainedStr' in type_.__name__):
        # This code here is messy: the value has to meet the type's constraints.
        # TODO: handle regex directly
        if 'uuid' == field_name:
            return str(uuid.uuid4())
        # some things like location_uuid in lists arrive here with field_name=''
        # Types matched only by class name may lack a regex attribute, so look it up defensively.
        regex = getattr(type_, 'regex', None)
        if regex and regex.pattern.startswith('^[0-9A-Fa-f]{8}'):
            return const.SAMPLE_UUID_STR
        if field_name == 'date_authorized':
            return str(date.today().isoformat())
        if field_name == 'oscal_version':
            return OSCAL_VERSION
        if 'uuid' in field_name:
            return const.SAMPLE_UUID_STR
        # Only case where are UUID is required but not in name.
        if field_name.rstrip('s') == 'member_of_organization':
            return const.SAMPLE_UUID_STR
        return const.REPLACE_ME
    if hasattr(type_, '__name__') and 'ConstrainedIntValue' in type_.__name__:
        # Create the smallest valid int at or above the lower bound(s); the upper bound (le/lt) is not considered.
        # Compare against None so that a bound of 0 (e.g. gt=0) is honoured instead of being treated as absent.
        lower_bounds = []
        if type_.ge is not None:
            lower_bounds.append(type_.ge)
        if type_.gt is not None:
            lower_bounds.append(type_.gt + 1)
        floor = max(lower_bounds, default=0)
        multiple = abs(type_.multiple_of) if type_.multiple_of else 1  # default to every integer
        # Round floor up to the nearest multiple; integer ceiling division avoids float rounding.
        return -(-floor // multiple) * multiple
    if safe_is_sub(type_, Enum):
        # keys and values diverge due to hypens in oscal names
        return type_(list(type_.__members__.values())[0])
    if type_ is str:
        if field_name == 'oscal_version':
            return OSCAL_VERSION
        return const.REPLACE_ME
    if type_ is pydantic.v1.networks.EmailStr:
        return pydantic.v1.networks.EmailStr('dummy@sample.com')
    if type_ is pydantic.v1.networks.AnyUrl:
        # TODO: Cleanup: this should be usable from a url.. but it's not inuitive.
        return pydantic.v1.networks.AnyUrl('https://sample.com/replaceme.html', scheme='http', host='sample.com')
    if type_ is list:
        raise err.TrestleError(f'Unable to generate sample for type {type_}')
    # default to empty dict for dict types, string for anything else
    # Check for both dict and generic dict types like dict[str, Any]
    if type_ is dict or (hasattr(type_, '__origin__') and type_.__origin__ is dict):
        return {}  # type: ignore[return-value]
    return const.REPLACE_ME

is_by_type(model_type) ¤

Check whether a type's sample is generated directly by value (i.e. whether it is the Base64 type).

Source code in trestle/core/generators.py
177
178
179
180
181
182
def is_by_type(model_type: Union[Type[TG], List[TG], Dict[str, TG]]) -> bool:
    """Return whether model_type equals type_base64 (the type of the module's sample_base64 object).

    Such types get their sample generated directly by value rather than field by field.
    """
    return bool(model_type == type_base64)

is_enum_method(type_) ¤

Test for method.

Source code in trestle/core/generators.py
71
72
73
74
75
76
77
78
79
80
def is_enum_method(type_: type) -> bool:
    """Test whether type_ is a union with the ``Methods`` enum among its members.

    Both ``typing.Union[...]`` and PEP 604 ``X | Y`` unions (``types.UnionType``) are recognised, consistent
    with the union handling elsewhere in this module. Members are matched on their string form
    ``"<enum 'Methods'>"``.
    """
    origin = utils.get_origin(type_)
    if origin != Union and str(origin) != "<class 'types.UnionType'>":
        return False
    return any(f'{arg}' == "<enum 'Methods'>" for arg in typing.get_args(type_))

is_enum_observation_type_valid_value(type_) ¤

Test for observation type valid value.

Source code in trestle/core/generators.py
 95
 96
 97
 98
 99
100
101
102
103
104
def is_enum_observation_type_valid_value(type_: type) -> bool:
    """Test whether type_ is a union with the ``ObservationTypeValidValues`` enum among its members.

    Both ``typing.Union[...]`` and PEP 604 ``X | Y`` unions (``types.UnionType``) are recognised, consistent
    with the union handling elsewhere in this module. Members are matched on their string form
    ``"<enum 'ObservationTypeValidValues'>"``.
    """
    origin = utils.get_origin(type_)
    if origin != Union and str(origin) != "<class 'types.UnionType'>":
        return False
    return any(f'{arg}' == "<enum 'ObservationTypeValidValues'>" for arg in typing.get_args(type_))

is_enum_task_valid_value(type_) ¤

Test for task valid value.

Source code in trestle/core/generators.py
83
84
85
86
87
88
89
90
91
92
def is_enum_task_valid_value(type_: type) -> bool:
    """Test whether type_ is a union with the ``TaskValidValues`` enum among its members.

    Both ``typing.Union[...]`` and PEP 604 ``X | Y`` unions (``types.UnionType``) are recognised, consistent
    with the union handling elsewhere in this module. Members are matched on their string form
    ``"<enum 'TaskValidValues'>"``.
    """
    origin = utils.get_origin(type_)
    if origin != Union and str(origin) != "<class 'types.UnionType'>":
        return False
    return any(f'{arg}' == "<enum 'TaskValidValues'>" for arg in typing.get_args(type_))

safe_is_sub(sub, parent) ¤

Return whether sub is a class that is a subclass of parent, without raising for non-class arguments.

Source code in trestle/core/generators.py
60
61
62
63
64
65
66
67
68
def safe_is_sub(sub: Any, parent: Any) -> bool:
    """Return True if sub is a class that is a subclass of parent, without raising for non-class values.

    Parameterized generics such as ``dict[str, Any]`` or ``List[int]`` cannot be passed to ``issubclass``,
    so they are reduced to their origin (``dict`` / ``list``) first. Anything that is not a class after
    that step (instances, ``Union[...]``, ``None``) yields False.
    """
    # typing.get_origin returns None for anything that is not a parameterized generic. Only substitute when it
    # found an origin, so a plain class that merely defines an ``__origin__`` attribute is still checked itself.
    origin = typing.get_origin(sub)
    if origin is not None:
        sub = origin
    return inspect.isclass(sub) and issubclass(sub, parent)

handler: python