Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions claasp/cipher_modules/models/cp/mzn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import subprocess
import time
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import timedelta

from minizinc import Instance, Model, Solver, Status
Expand Down Expand Up @@ -52,6 +53,18 @@
CONSTRAINT_TYPE_ERROR = "Constraint type not defined"


@dataclass
class MiniZincModelParts:
prefix: list[str] = field(default_factory=list)
variables: list[str] = field(default_factory=list)
constraints: list[str] = field(default_factory=list)
outputs: list[str] = field(default_factory=list)
carries_outputs: list[str] = field(default_factory=list)

def lines(self):
return self.prefix + self.variables + self.constraints + self.outputs + self.carries_outputs


class MznModel:

def __init__(self, cipher, sat_or_milp='sat'):
Expand Down Expand Up @@ -106,6 +119,20 @@ def initialise_model(self):
self.component_probability_var = {}
self._model_prefix = ['include "globals.mzn";', f"{usefulfunctions.MINIZINC_USEFUL_FUNCTIONS}"]

def current_model_parts(self):
return MiniZincModelParts(
variables=list(self._variables_list),
constraints=list(self._model_constraints),
outputs=list(self.mzn_output_directives),
carries_outputs=list(self.mzn_carries_output_directives),
)

def assemble_model_lines(self, parts=None):
return (parts or self.current_model_parts()).lines()

def assemble_model(self, parts=None):
return "\n".join(self.assemble_model_lines(parts)) + "\n"

def add_comment(self, comment):
"""
Write a 'comment' at the beginning of the model.
Expand Down Expand Up @@ -813,8 +840,6 @@ def solve(
):
truncated = True

mzn_model = self._variables_list + self._model_constraints

solutions = []
if solve_external:
command = self.get_command_for_solver_process(
Expand All @@ -824,15 +849,15 @@ def solve(
timeout_in_seconds_,
intermediate_solutions=intermediate_solutions_,
)
model = "\n".join(mzn_model)
model = self.assemble_model()
start = time.time()
solver_process = subprocess.run(command, input=model, capture_output=True, text=True)
end = time.time()
solve_time = end - start
if solver_process.returncode >= 0:
solver_output = solver_process.stdout.splitlines()
else:
mzn_model_string = "\n".join(mzn_model)
mzn_model_string = self.assemble_model()
solver_name_mzn = Solver.lookup(solver_name)
bit_mzn_model = Model()
bit_mzn_model.add_string(mzn_model_string)
Expand Down Expand Up @@ -972,9 +997,12 @@ def solve_for_ARX(
sage: result.statistics['nSolutions']
1
"""
constraints = self._model_constraints
variables = self._variables_list
mzn_model_string = "\n".join(constraints) + "\n".join(variables)
mzn_model_string = self.assemble_model(
MiniZincModelParts(
variables=list(self._model_constraints),
constraints=list(self._variables_list),
)
)
solver_name_mzn = Solver.lookup(solver_name)
bit_mzn_model = Model()
bit_mzn_model.add_string(mzn_model_string)
Expand Down Expand Up @@ -1077,20 +1105,22 @@ def write_minizinc_model_to_file(self, file_path, prefix=""):
- ``file_path`` -- **string**; the path of the file that will contain the model
- ``prefix`` -- **str** (default: ``)
"""
model_string = (
"\n".join(self.mzn_comments)
+ "\n".join(self._variables_list)
+ "\n".join(self._model_constraints)
+ "\n".join(self.mzn_output_directives)
+ "\n".join(self.mzn_carries_output_directives)
)
if prefix == "":
filename = f"{file_path}/{self.cipher_id}_mzn_{self.sat_or_milp}.mzn"
else:
filename = f"{file_path}/{prefix}_{self.cipher_id}_mzn_{self.sat_or_milp}.mzn"

with open(filename, "w") as file:
file.write(model_string)
file.write(
self.assemble_model(
MiniZincModelParts(
variables=self.mzn_comments + list(self._variables_list),
constraints=list(self._model_constraints),
outputs=list(self.mzn_output_directives),
carries_outputs=list(self.mzn_carries_output_directives),
)
)
)

@property
def cipher(self):
Expand Down Expand Up @@ -1152,4 +1182,4 @@ def model_variables(self):
"""
if not self._variables_list:
raise ValueError("No model generated")
return self._variables_list
return self._variables_list
36 changes: 35 additions & 1 deletion tests/unit/cipher_modules/models/cp/mzn_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher
from claasp.ciphers.block_ciphers.midori_block_cipher import MidoriBlockCipher
import claasp.cipher_modules.models.cp.mzn_model as mzn_model_module
from claasp.cipher_modules.models.cp.mzn_model import MznModel
from claasp.cipher_modules.models.cp.mzn_model import MiniZincModelParts, MznModel
from claasp.ciphers.block_ciphers.raiden_block_cipher import RaidenBlockCipher
from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import MznXorDifferentialModel
from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model_arx_optimized import (
Expand All @@ -28,6 +28,40 @@ def total_seconds(self):
return self._seconds


def test_assemble_model_preserves_legacy_order():
speck = SpeckBlockCipher(number_of_rounds=1)
model = MznModel(speck)
model._variables_list = ["var int: x;"]
model._model_constraints = ["constraint x = 1;", "solve satisfy;"]
model.mzn_output_directives = ['output ["x=", show(x)];']

assert model.assemble_model() == (
"var int: x;\n"
"constraint x = 1;\n"
"solve satisfy;\n"
'output ["x=", show(x)];\n'
)


def test_assemble_model_accepts_explicit_parts():
speck = SpeckBlockCipher(number_of_rounds=1)
model = MznModel(speck)
parts = MiniZincModelParts(
prefix=['include "globals.mzn";'],
variables=["var int: x;"],
constraints=["constraint x = 1;", "solve satisfy;"],
outputs=['output ["x=", show(x)];'],
)

assert model.assemble_model(parts) == (
'include "globals.mzn";\n'
"var int: x;\n"
"constraint x = 1;\n"
"solve satisfy;\n"
'output ["x=", show(x)];\n'
)


@pytest.mark.parametrize(
"solver_stats,expected_time",
[
Expand Down
Loading