Refactor constraints so that each constraint is it's own class (#23753)
diff --git a/src/controller/python/BUILD.gn b/src/controller/python/BUILD.gn
index 9108c12..a0381f0 100644
--- a/src/controller/python/BUILD.gn
+++ b/src/controller/python/BUILD.gn
@@ -230,6 +230,7 @@
"chip/utils/CommissioningBuildingBlocks.py",
"chip/utils/__init__.py",
"chip/yaml/__init__.py",
+ "chip/yaml/constraints.py",
"chip/yaml/data_model_lookup.py",
"chip/yaml/errors.py",
"chip/yaml/format_converter.py",
diff --git a/src/controller/python/chip/yaml/constraints.py b/src/controller/python/chip/yaml/constraints.py
new file mode 100644
index 0000000..933d74f
--- /dev/null
+++ b/src/controller/python/chip/yaml/constraints.py
@@ -0,0 +1,218 @@
+#
+# Copyright (c) 2022 Project CHIP Authors
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC, abstractmethod
+import chip.yaml.format_converter as Converter
+from .variable_storage import VariableStorage
+
+
+class ConstraintValidationError(Exception):
+ def __init__(self, message):
+ super().__init__(message)
+
+
+class BaseConstraint(ABC):
+ '''Constrain Interface'''
+
+ @abstractmethod
+ def is_met(self, response) -> bool:
+ pass
+
+
+class _LoadableConstraint(BaseConstraint):
+ '''Constraints where value might be stored in VariableStorage needing runtime load.'''
+
+ def __init__(self, value, field_type, variable_storage: VariableStorage):
+ self._variable_storage = variable_storage
+ # When not none _indirect_value_key is binding a name to the constraint value, and the
+ # actual value can only be looked-up dynamically, which is why this is a key name.
+ self._indirect_value_key = None
+ self._value = None
+
+ if value is None:
+ # Default values set above is all we need here.
+ return
+
+ if isinstance(value, str) and self._variable_storage.is_key_saved(value):
+ self._indirect_value_key = value
+ else:
+ self._value = Converter.convert_yaml_type(
+ value, field_type)
+
+ def get_value(self):
+ '''Gets the current value of the constraint.
+
+ This method accounts for getting the runtime saved value from DUT previous responses.
+ '''
+ if self._indirect_value_key:
+ return self._variable_storage.load(self._indirect_value_key)
+ return self._value
+
+
+class _ConstraintHasValue(BaseConstraint):
+ def __init__(self, has_value):
+ self._has_value = has_value
+
+ def is_met(self, response) -> bool:
+ raise ConstraintValidationError('HasValue constraint currently not implemented')
+
+
+class _ConstraintType(BaseConstraint):
+ def __init__(self, type):
+ self._type = type
+
+ def is_met(self, response) -> bool:
+ raise ConstraintValidationError('Type constraint currently not implemented')
+
+
+class _ConstraintStartsWith(BaseConstraint):
+ def __init__(self, starts_with):
+ self._starts_with = starts_with
+
+ def is_met(self, response) -> bool:
+ return response.startswith(self._starts_with)
+
+
+class _ConstraintEndsWith(BaseConstraint):
+ def __init__(self, ends_with):
+ self._ends_with = ends_with
+
+ def is_met(self, response) -> bool:
+ return response.endswith(self._ends_with)
+
+
+class _ConstraintIsUpperCase(BaseConstraint):
+ def __init__(self, is_upper_case):
+ self._is_upper_case = is_upper_case
+
+ def is_met(self, response) -> bool:
+ return response.isupper() == self._is_upper_case
+
+
+class _ConstraintIsLowerCase(BaseConstraint):
+ def __init__(self, is_lower_case):
+ self._is_lower_case = is_lower_case
+
+ def is_met(self, response) -> bool:
+ return response.islower() == self._is_lower_case
+
+
+class _ConstraintMinValue(_LoadableConstraint):
+ def __init__(self, min_value, field_type, variable_storage: VariableStorage):
+ super().__init__(min_value, field_type, variable_storage)
+
+ def is_met(self, response) -> bool:
+ min_value = self.get_value()
+ return response >= min_value
+
+
+class _ConstraintMaxValue(_LoadableConstraint):
+ def __init__(self, max_value, field_type, variable_storage: VariableStorage):
+ super().__init__(max_value, field_type, variable_storage)
+
+ def is_met(self, response) -> bool:
+ max_value = self.get_value()
+ return response <= max_value
+
+
+class _ConstraintContains(BaseConstraint):
+ def __init__(self, contains):
+ self._contains = contains
+
+ def is_met(self, response) -> bool:
+ return set(self._contains).issubset(response)
+
+
+class _ConstraintExcludes(BaseConstraint):
+ def __init__(self, excludes):
+ self._excludes = excludes
+
+ def is_met(self, response) -> bool:
+ return set(self._excludes).isdisjoint(response)
+
+
+class _ConstraintHasMaskSet(BaseConstraint):
+ def __init__(self, has_masks_set):
+ self._has_masks_set = has_masks_set
+
+ def is_met(self, response) -> bool:
+ return all([(response & mask) == mask for mask in self._has_masks_set])
+
+
+class _ConstraintHasMaskClear(BaseConstraint):
+ def __init__(self, has_masks_clear):
+ self._has_masks_clear = has_masks_clear
+
+ def is_met(self, response) -> bool:
+ return all([(response & mask) == 0 for mask in self._has_masks_clear])
+
+
+class _ConstraintNotValue(_LoadableConstraint):
+ def __init__(self, not_value, field_type, variable_storage: VariableStorage):
+ super().__init__(not_value, field_type, variable_storage)
+
+ def is_met(self, response) -> bool:
+ not_value = self.get_value()
+ return response != not_value
+
+
+def get_constraints(constraints, field_type,
+ variable_storage: VariableStorage) -> list[BaseConstraint]:
+ _constraints = []
+ if 'hasValue' in constraints:
+ _constraints.append(_ConstraintHasValue(constraints.get('hasValue')))
+
+ if 'type' in constraints:
+ _constraints.append(_ConstraintType(constraints.get('type')))
+
+ if 'startsWith' in constraints:
+ _constraints.append(_ConstraintStartsWith(constraints.get('startsWith')))
+
+ if 'endsWith' in constraints:
+ _constraints.append(_ConstraintEndsWith(constraints.get('endsWith')))
+
+ if 'isUpperCase' in constraints:
+ _constraints.append(_ConstraintIsUpperCase(constraints.get('isUpperCase')))
+
+ if 'isLowerCase' in constraints:
+ _constraints.append(_ConstraintIsLowerCase(constraints.get('isLowerCase')))
+
+ if 'minValue' in constraints:
+ _constraints.append(_ConstraintMinValue(
+ constraints.get('minValue'), field_type, variable_storage))
+
+ if 'maxValue' in constraints:
+ _constraints.append(_ConstraintMaxValue(
+ constraints.get('maxValue'), field_type, variable_storage))
+
+ if 'contains' in constraints:
+ _constraints.append(_ConstraintContains(constraints.get('contains')))
+
+ if 'excludes' in constraints:
+ _constraints.append(_ConstraintExcludes(constraints.get('excludes')))
+
+ if 'hasMasksSet' in constraints:
+ _constraints.append(_ConstraintHasMaskSet(constraints.get('hasMasksSet')))
+
+ if 'hasMasksClear' in constraints:
+ _constraints.append(_ConstraintHasMaskClear(constraints.get('hasMasksClear')))
+
+ if 'notValue' in constraints:
+ _constraints.append(_ConstraintNotValue(
+ constraints.get('notValue'), field_type, variable_storage))
+
+ return _constraints
diff --git a/src/controller/python/chip/yaml/parser.py b/src/controller/python/chip/yaml/parser.py
index 8fc5cd5..a3f38f7 100644
--- a/src/controller/python/chip/yaml/parser.py
+++ b/src/controller/python/chip/yaml/parser.py
@@ -18,7 +18,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from chip import ChipDeviceCtrl
-from chip.clusters.Types import NullValue
from chip.tlv import float32
import yaml
import stringcase
@@ -30,6 +29,7 @@
from .data_model_lookup import *
import chip.yaml.format_converter as Converter
from .variable_storage import VariableStorage
+from .constraints import get_constraints
_SUCCESS_STATUS_CODE = "SUCCESS"
_NODE_ID_DEFAULT = 0x12345
@@ -50,110 +50,6 @@
config_values: dict = None
-class _ConstraintValue:
- '''Constraints that are numeric primitive data types'''
-
- def __init__(self, value, field_type, context: _ExecutionContext):
- self._variable_storage = context.variable_storage
- # When not none _indirect_value_key is binding a name to the constraint value, and the
- # actual value can only be looked-up dynamically, which is why this is a key name.
- self._indirect_value_key = None
- self._value = None
-
- if value is None:
- # Default values set above is all we need here.
- return
-
- if isinstance(value, str) and self._variable_storage.is_key_saved(value):
- self._indirect_value_key = value
- else:
- self._value = Converter.convert_yaml_type(
- value, field_type)
-
- def get_value(self):
- '''Gets the current value of the constraint.
-
- This method accounts for getting the runtime saved value from DUT previous responses.
- '''
- if self._indirect_value_key:
- return self._variable_storage.load(self._indirect_value_key)
- return self._value
-
-
-class _Constraints:
- def __init__(self, constraints: dict, field_type, context: _ExecutionContext):
- self._variable_storage = context.variable_storage
- self._has_value = constraints.get('hasValue')
- self._type = constraints.get('type')
- self._starts_with = constraints.get('startsWith')
- self._ends_with = constraints.get('endsWith')
- self._is_upper_case = constraints.get('isUpperCase')
- self._is_lower_case = constraints.get('isLowerCase')
- self._min_value = _ConstraintValue(constraints.get('minValue'), field_type,
- context)
- self._max_value = _ConstraintValue(constraints.get('maxValue'), field_type,
- context)
- self._contains = constraints.get('contains')
- self._excludes = constraints.get('excludes')
- self._has_masks_set = constraints.get('hasMasksSet')
- self._has_masks_clear = constraints.get('hasMasksClear')
- self._not_value = _ConstraintValue(constraints.get('notValue'), field_type,
- context)
-
- def are_constrains_met(self, response) -> bool:
- return_value = True
-
- if self._has_value:
- logger.warn(f'HasValue constraint currently not implemented, forcing failure')
- return_value = False
-
- if self._type:
- logger.warn(f'Type constraint currently not implemented, forcing failure')
- return_value = False
-
- if self._starts_with and not response.startswith(self._starts_with):
- return_value = False
-
- if self._ends_with and not response.endswith(self._ends_with):
- return_value = False
-
- if self._is_upper_case and not response.isupper():
- return_value = False
-
- if self._is_lower_case and not response.islower():
- return_value = False
-
- min_value = self._min_value.get_value()
- if response is not NullValue and min_value and response < min_value:
- return_value = False
-
- max_value = self._max_value.get_value()
- if response is not NullValue and max_value and response > max_value:
- return_value = False
-
- if self._contains and not set(self._contains).issubset(response):
- return_value = False
-
- if self._excludes and not set(self._excludes).isdisjoint(response):
- return_value = False
-
- if self._has_masks_set:
- for mask in self._has_masks_set:
- if (response & mask) != mask:
- return_value = False
-
- if self._has_masks_clear:
- for mask in self._has_masks_clear:
- if (response & mask) != 0:
- return_value = False
-
- not_value = self._not_value.get_value()
- if not_value and response == not_value:
- return_value = False
-
- return return_value
-
-
class _VariableToSave:
def __init__(self, variable_name: str, variable_storage: VariableStorage):
self._variable_name = variable_name
@@ -311,7 +207,7 @@
'''
super().__init__(item['label'])
self._attribute_name = stringcase.pascalcase(item['attribute'])
- self._constraints = None
+ self._constraints = []
self._cluster = cluster
self._cluster_object = None
self._request_object = None
@@ -362,9 +258,9 @@
constraints = self._expected_raw_response.get('constraints')
if constraints:
- self._constraints = _Constraints(constraints,
- self._request_object.attribute_type.Type,
- context)
+ self._constraints = get_constraints(constraints,
+ self._request_object.attribute_type.Type,
+ context.variable_storage)
def run_action(self, dev_ctrl: ChipDeviceCtrl, endpoint: int, node_id: int):
try:
@@ -391,7 +287,7 @@
if self._variable_to_save is not None:
self._variable_to_save.save_response(parsed_resp)
- if self._constraints and not self._constraints.are_constrains_met(parsed_resp):
+ if not all([constraint.is_met(parsed_resp) for constraint in self._constraints]):
logger.error(f'Constraints check failed')
# TODO how should we fail the test here?