Coverage for colour/recovery/tests/test_otsu2018.py: 100%

268 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-15 19:01 +1300

1"""Define the unit tests for the :mod:`colour.recovery.jakob2019` module.""" 

2 

3from __future__ import annotations 

4 

5import os 

6import platform 

7import shutil 

8import tempfile 

9 

10import numpy as np 

11import pytest 

12 

13from colour.characterisation import SDS_COLOURCHECKERS 

14from colour.colorimetry import ( 

15 handle_spectral_arguments, 

16 reshape_msds, 

17 reshape_sd, 

18 sd_to_XYZ, 

19 sds_and_msds_to_msds, 

20) 

21from colour.constants import TOLERANCE_ABSOLUTE_TESTS 

22from colour.difference import delta_E_CIE1976 

23from colour.models import XYZ_to_Lab, XYZ_to_xy 

24from colour.recovery import ( 

25 SPECTRAL_SHAPE_OTSU2018, 

26 Dataset_Otsu2018, 

27 Tree_Otsu2018, 

28 XYZ_to_sd_Otsu2018, 

29) 

30from colour.recovery.otsu2018 import ( 

31 DATASET_REFERENCE_OTSU2018, 

32 Data_Otsu2018, 

33 Node_Otsu2018, 

34 PartitionAxis, 

35) 

36from colour.utilities import domain_range_scale, metric_mse 

37 

38__author__ = "Colour Developers" 

39__copyright__ = "Copyright 2013 Colour Developers" 

40__license__ = "BSD-3-Clause - https://opensource.org/licenses/BSD-3-Clause" 

41__maintainer__ = "Colour Developers" 

42__email__ = "colour-developers@colour-science.org" 

43__status__ = "Production" 

44 

45__all__ = [ 

46 "TestDataset_Otsu2018", 

47 "TestXYZ_to_sd_Otsu2018", 

48 "TestData_Otsu2018", 

49 "TestNode_Otsu2018", 

50 "TestTree_Otsu2018", 

51] 

52 

53 

class TestDataset_Otsu2018:
    """
    Define :class:`colour.recovery.otsu2018.Dataset_Otsu2018` definition unit
    tests methods.
    """

    def setup_method(self) -> None:
        """
        Initialise the common tests attributes: the reference dataset, a
        sample *xy* chromaticity, and a temporary *.npz* file into which the
        reference dataset is written.
        """

        self._dataset = DATASET_REFERENCE_OTSU2018
        self._xy = np.array([0.54369557, 0.32107944])

        self._temporary_directory = tempfile.mkdtemp()
        self._path = os.path.join(self._temporary_directory, "Test_Otsu2018.npz")

        self._dataset.write(self._path)

    def teardown_method(self) -> None:
        """Remove the temporary directory created by :meth:`setup_method`."""

        shutil.rmtree(self._temporary_directory)

    @staticmethod
    def _assert_reference_layout(dataset: Dataset_Otsu2018) -> None:
        """
        Assert that given dataset has the spectral shape and the array
        dimensions of the reference dataset.
        """

        assert dataset.shape == SPECTRAL_SHAPE_OTSU2018
        assert dataset.basis_functions is not None
        assert dataset.basis_functions.shape == (8, 3, 36)
        assert dataset.means is not None
        assert dataset.means.shape == (8, 36)
        assert dataset.selector_array is not None
        assert dataset.selector_array.shape == (7, 4)

    def test_required_attributes(self) -> None:
        """Test the presence of required attributes."""

        members = dir(Dataset_Otsu2018)

        for name in ("shape", "basis_functions", "means", "selector_array"):
            assert name in members

    def test_required_methods(self) -> None:
        """Test the presence of required methods."""

        members = dir(Dataset_Otsu2018)

        for name in ("__init__", "__str__", "select", "cluster", "read", "write"):
            assert name in members

    def test_shape(self) -> None:
        """Test :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.shape` property."""

        assert self._dataset.shape == SPECTRAL_SHAPE_OTSU2018

    def test_basis_functions(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.basis_functions`
        property.
        """

        basis_functions = self._dataset.basis_functions

        assert basis_functions is not None
        assert basis_functions.shape == (8, 3, 36)

    def test_means(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.means`
        property.
        """

        means = self._dataset.means

        assert means is not None
        assert means.shape == (8, 36)

    def test_selector_array(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.selector_array`
        property.
        """

        selector_array = self._dataset.selector_array

        assert selector_array is not None
        assert selector_array.shape == (7, 4)

    def test__str__(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.__str__` method."""

        assert str(self._dataset) == "Dataset_Otsu2018(8 basis functions)"

        assert str(Dataset_Otsu2018()) == "Dataset_Otsu2018()"

    def test_select(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.select` method."""

        assert self._dataset.select(self._xy) == 6

    def test_raise_exception_select(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.select` method
        raised exception.
        """

        with pytest.raises(ValueError):
            Dataset_Otsu2018().select(np.array([0, 0]))

    def test_cluster(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.cluster` method."""

        basis_functions, means = self._dataset.cluster(self._xy)

        assert basis_functions.shape == (3, 36)
        assert means.shape == (36,)

    def test_raise_exception_cluster(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.cluster` method
        raised exception.
        """

        with pytest.raises(ValueError):
            Dataset_Otsu2018().cluster(np.array([0, 0]))

    def test_read(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.read` method."""

        loaded = Dataset_Otsu2018()
        loaded.read(self._path)

        self._assert_reference_layout(loaded)

    def test_write(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.write` method."""

        self._dataset.write(self._path)

        loaded = Dataset_Otsu2018()
        loaded.read(self._path)

        self._assert_reference_layout(loaded)

    def test_raise_exception_write(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Dataset_Otsu2018.write` method
        raised exception.
        """

        with pytest.raises(ValueError):
            Dataset_Otsu2018().write("")

209 

class TestXYZ_to_sd_Otsu2018:
    """
    Define :func:`colour.recovery.otsu2018.XYZ_to_sd_Otsu2018` definition unit
    tests methods.
    """

    def setup_method(self) -> None:
        """
        Initialise the common tests attributes: the default colour matching
        functions and illuminant reshaped to
        :attr:`colour.recovery.SPECTRAL_SHAPE_OTSU2018`, and the illuminant
        *CIE XYZ* tristimulus values and *xy* chromaticity coordinates.
        """

        self._shape = SPECTRAL_SHAPE_OTSU2018
        self._cmfs, self._sd_D65 = handle_spectral_arguments(shape_default=self._shape)
        self._XYZ_D65 = sd_to_XYZ(self._sd_D65)
        self._xy_D65 = XYZ_to_xy(self._XYZ_D65)

    def test_XYZ_to_sd_Otsu2018(self) -> None:
        """Test :func:`colour.recovery.otsu2018.XYZ_to_sd_Otsu2018` definition."""

        # Tests the round-trip with values of a colour checker: the recovered
        # spectral distribution must be close to the original one (MSE) and
        # must reproduce its colour (CIE 1976 colour difference).
        for sd in SDS_COLOURCHECKERS["ColorChecker N Ohta"].values():
            XYZ = sd_to_XYZ(sd, self._cmfs, self._sd_D65) / 100
            Lab = XYZ_to_Lab(XYZ, self._xy_D65)

            recovered_sd = XYZ_to_sd_Otsu2018(XYZ, self._cmfs, self._sd_D65, clip=False)
            recovered_XYZ = sd_to_XYZ(recovered_sd, self._cmfs, self._sd_D65) / 100
            recovered_Lab = XYZ_to_Lab(recovered_XYZ, self._xy_D65)

            error = metric_mse(
                reshape_sd(sd, SPECTRAL_SHAPE_OTSU2018).values,
                recovered_sd.values,
            )
            assert error < 0.02

            delta_E = delta_E_CIE1976(Lab, recovered_Lab)
            assert delta_E < 1e-12

    def test_raise_exception_XYZ_to_sd_Otsu2018(self) -> None:
        """
        Test :func:`colour.recovery.otsu2018.XYZ_to_sd_Otsu2018` definition
        raised exception.
        """

        # An empty dataset passed as the fourth argument must be rejected.
        with pytest.raises(ValueError):
            XYZ_to_sd_Otsu2018(
                np.array([0, 0, 0]),
                self._cmfs,
                self._sd_D65,
                Dataset_Otsu2018(),
            )

    def test_domain_range_scale_XYZ_to_sd_Otsu2018(self) -> None:
        """
        Test :func:`colour.recovery.otsu2018.XYZ_to_sd_Otsu2018` definition
        domain and range scale support.
        """

        XYZ_i = np.array([0.20654008, 0.12197225, 0.05136952])
        XYZ_o = sd_to_XYZ(
            XYZ_to_sd_Otsu2018(XYZ_i, self._cmfs, self._sd_D65),
            self._cmfs,
            self._sd_D65,
        )

        # Each entry is (scale name, input factor, expected output factor).
        d_r = (("reference", 1, 1), ("1", 1, 0.01), ("100", 100, 1))
        for scale, factor_a, factor_b in d_r:
            with domain_range_scale(scale):
                np.testing.assert_allclose(
                    sd_to_XYZ(
                        XYZ_to_sd_Otsu2018(XYZ_i * factor_a, self._cmfs, self._sd_D65),
                        self._cmfs,
                        self._sd_D65,
                    ),
                    XYZ_o * factor_b,
                    atol=TOLERANCE_ABSOLUTE_TESTS,
                )

285 

286 

class TestData_Otsu2018:
    """
    Define :class:`colour.recovery.otsu2018.Data_Otsu2018` definition unit
    tests methods.
    """

    def setup_method(self) -> None:
        """
        Initialise the common tests attributes: the default colour matching
        functions and illuminant, the *ColorChecker N Ohta* reflectances
        reshaped to :attr:`colour.recovery.SPECTRAL_SHAPE_OTSU2018` and
        transposed, and a :class:`Data_Otsu2018` instance built from them.
        """

        self._shape = SPECTRAL_SHAPE_OTSU2018
        self._cmfs, self._sd_D65 = handle_spectral_arguments(shape_default=self._shape)

        self._reflectances = np.transpose(
            reshape_msds(
                sds_and_msds_to_msds(
                    SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
                ),
                self._shape,
            ).values
        )

        self._data = Data_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)

    def test_required_attributes(self) -> None:
        """Test the presence of required attributes."""

        required_attributes = (
            "reflectances",
            "cmfs",
            "illuminant",
            "basis_functions",
            "mean",
        )

        for attribute in required_attributes:
            assert attribute in dir(Data_Otsu2018)

    def test_required_methods(self) -> None:
        """Test the presence of required methods."""

        required_methods = (
            "__init__",
            "__str__",
            "__len__",
            "origin",
            "partition",
            "PCA",
            "reconstruct",
            "reconstruction_error",
        )

        for method in required_methods:
            assert method in dir(Data_Otsu2018)

    def test_reflectances(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Data_Otsu2018.reflectances`
        property.
        """

        assert self._data.reflectances is self._reflectances

    def test_cmfs(self) -> None:
        """Test :attr:`colour.recovery.otsu2018.Data_Otsu2018.cmfs` property."""

        assert self._data.cmfs is self._cmfs

    def test_illuminant(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Data_Otsu2018.illuminant`
        property.
        """

        assert self._data.illuminant is self._sd_D65

    def test_basis_functions(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Data_Otsu2018.basis_functions`
        property.
        """

        data = Data_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)

        # The basis functions are only available once PCA has been run.
        assert data.basis_functions is None

        data.PCA()

        assert data.basis_functions is not None
        assert data.basis_functions.shape == (3, 36)

    def test_mean(self) -> None:
        """Test :attr:`colour.recovery.otsu2018.Data_Otsu2018.mean` property."""

        data = Data_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)

        # The mean is only available once PCA has been run.
        assert data.mean is None

        data.PCA()

        assert data.mean is not None
        assert data.mean.shape == (36,)

    def test__str__(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.__str__` method."""

        assert str(self._data) == "Data_Otsu2018(24 Reflectances)"

    def test__len__(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.__len__` method."""

        assert len(self._data) == 24

    def test_origin(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.origin` method."""

        np.testing.assert_allclose(
            self._data.origin(4, 1),
            0.255284008578559,
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

    def test_raise_exception_origin(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.origin` method
        raised exception.
        """

        with pytest.raises(ValueError):
            Data_Otsu2018(None, self._cmfs, self._sd_D65).origin(4, 1)

    def test_partition(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.partition` method."""

        partition = self._data.partition(PartitionAxis(4, 1))

        assert len(partition) == 2

    def test_raise_exception_partition(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.partition` method
        raised exception.
        """

        with pytest.raises(ValueError):
            Data_Otsu2018(None, self._cmfs, self._sd_D65).partition(
                PartitionAxis(4, 1)
            )

    # NOTE(review): the expected values below are presumably platform
    # dependent (e.g. LAPACK/BLAS backend); the skip reason states they are
    # only checked on macOS, so skip on anything that is not "Darwin" rather
    # than on an explicit deny-list that lets other systems through.
    @pytest.mark.skipif(
        platform.system() != "Darwin",
        reason="PCA tests only run on macOS",
    )
    def test_PCA(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.PCA` method."""

        data = Data_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)

        data.PCA()

        assert data.basis_functions is not None

        # Absolute values are compared because the sign of each principal
        # component is arbitrary.
        np.testing.assert_allclose(
            np.abs(data.basis_functions),
            np.array(
                [
                    [
                        0.04391241,
                        0.08560996,
                        0.15556120,
                        0.20826672,
                        0.22981218,
                        0.23117641,
                        0.22718022,
                        0.21742869,
                        0.19854261,
                        0.16868383,
                        0.12020268,
                        0.05958463,
                        0.01015508,
                        0.08775193,
                        0.16957532,
                        0.23186776,
                        0.26516404,
                        0.27409402,
                        0.27856619,
                        0.27685075,
                        0.25597708,
                        0.21331000,
                        0.15372029,
                        0.08746878,
                        0.02744494,
                        0.01725581,
                        0.04756055,
                        0.07184639,
                        0.09090063,
                        0.10317253,
                        0.10830387,
                        0.10872694,
                        0.10645999,
                        0.10766424,
                        0.11170078,
                        0.11620896,
                    ],
                    [
                        0.03137588,
                        0.06204234,
                        0.11364884,
                        0.17579436,
                        0.20914074,
                        0.22152351,
                        0.23120105,
                        0.24039823,
                        0.24730359,
                        0.25195045,
                        0.25237533,
                        0.24672212,
                        0.23538236,
                        0.22094141,
                        0.20389065,
                        0.18356599,
                        0.15952882,
                        0.13567812,
                        0.11401807,
                        0.09178015,
                        0.06539517,
                        0.03173809,
                        0.00658524,
                        0.04710763,
                        0.08379987,
                        0.11074555,
                        0.12606191,
                        0.13630094,
                        0.13988107,
                        0.14193361,
                        0.14671866,
                        0.15164795,
                        0.15772737,
                        0.16328073,
                        0.16588768,
                        0.16947164,
                    ],
                    [
                        0.01360289,
                        0.02375832,
                        0.04262545,
                        0.07345243,
                        0.09081235,
                        0.09227928,
                        0.08922710,
                        0.08626299,
                        0.08584571,
                        0.08843734,
                        0.09475094,
                        0.10376740,
                        0.11331399,
                        0.12109706,
                        0.12678070,
                        0.13401030,
                        0.14417036,
                        0.15408359,
                        0.16265529,
                        0.17079814,
                        0.17972656,
                        0.19005983,
                        0.20053986,
                        0.21017531,
                        0.21808806,
                        0.22347400,
                        0.22650876,
                        0.22895376,
                        0.22982598,
                        0.23001787,
                        0.23036398,
                        0.22917409,
                        0.22684271,
                        0.22387883,
                        0.22065773,
                        0.21821049,
                    ],
                ]
            ),
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

        assert data.mean is not None

        np.testing.assert_allclose(
            data.mean,
            np.array(
                [
                    0.08795833,
                    0.12050000,
                    0.16787500,
                    0.20675000,
                    0.22329167,
                    0.22837500,
                    0.23229167,
                    0.23579167,
                    0.23658333,
                    0.23779167,
                    0.23866667,
                    0.23975000,
                    0.24345833,
                    0.25054167,
                    0.25791667,
                    0.26150000,
                    0.26437500,
                    0.26566667,
                    0.26475000,
                    0.26554167,
                    0.27137500,
                    0.28279167,
                    0.29529167,
                    0.31070833,
                    0.32575000,
                    0.33829167,
                    0.34675000,
                    0.35554167,
                    0.36295833,
                    0.37004167,
                    0.37854167,
                    0.38675000,
                    0.39587500,
                    0.40266667,
                    0.40683333,
                    0.41287500,
                ]
            ),
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

    def test_reconstruct(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.reconstruct`
        method.
        """

        data = Data_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)

        data.PCA()

        np.testing.assert_allclose(
            data.reconstruct(
                np.array(
                    [
                        0.20654008,
                        0.12197225,
                        0.05136952,
                    ]
                )
            ).values,
            np.array(
                [
                    0.06899964,
                    0.08241919,
                    0.09768650,
                    0.08938555,
                    0.07872582,
                    0.07140930,
                    0.06385099,
                    0.05471747,
                    0.04281364,
                    0.03073280,
                    0.01761134,
                    0.00772535,
                    0.00379120,
                    0.00405617,
                    0.00595014,
                    0.01323536,
                    0.03229711,
                    0.05661531,
                    0.07763041,
                    0.10271461,
                    0.14276781,
                    0.20239859,
                    0.27288559,
                    0.35044541,
                    0.42170481,
                    0.47567859,
                    0.50910276,
                    0.53578140,
                    0.55251101,
                    0.56530032,
                    0.58029915,
                    0.59367723,
                    0.60830542,
                    0.62100871,
                    0.62881635,
                    0.63971254,
                ]
            ),
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

    def test_raise_exception_reconstruct(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.reconstruct` method
        raised exception.
        """

        with pytest.raises(ValueError):
            Data_Otsu2018(None, self._cmfs, self._sd_D65).reconstruct(
                np.array([0, 0, 0])
            )

    def test_reconstruction_error(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.\
reconstruction_error` method.
        """

        data = Data_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)

        np.testing.assert_allclose(
            data.reconstruction_error(),
            2.753352549148681,
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

    def test_raise_exception_reconstruction_error(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Data_Otsu2018.\
reconstruction_error` method raised exception.
        """

        with pytest.raises(ValueError):
            Data_Otsu2018(None, self._cmfs, self._sd_D65).reconstruction_error()

722 

723 

class TestNode_Otsu2018:
    """
    Define :class:`colour.recovery.otsu2018.Node_Otsu2018` definition unit
    tests methods.
    """

    def setup_method(self) -> None:
        """
        Initialise the common tests attributes.

        An optimised :class:`colour.recovery.otsu2018.Tree_Otsu2018` is built
        from the *ColorChecker N Ohta* reflectances. The first leaf whose
        parent has exactly two children provides the node triplet under test:
        ``_node_a`` (the parent) and ``_node_b``, ``_node_c`` (its children).
        """

        self._shape = SPECTRAL_SHAPE_OTSU2018
        self._cmfs, self._sd_D65 = handle_spectral_arguments(shape_default=self._shape)

        self._reflectances = sds_and_msds_to_msds(
            SDS_COLOURCHECKERS["ColorChecker N Ohta"].values()
        )

        self._tree = Tree_Otsu2018(self._reflectances)
        self._tree.optimise()
        for leaf in self._tree.leaves:
            parent = leaf.parent
            # A leaf without a parent (e.g. an unsplit root) cannot provide a
            # parent/children triplet.
            if parent is not None and len(parent.children) == 2:
                self._node_a = parent
                self._node_b, self._node_c = self._node_a.children
                break
        else:
            # Without this, every test would fail later with an unrelated
            # "AttributeError: ... has no attribute '_node_a'".
            pytest.fail("No tree leaf has a parent node with exactly two children!")

        self._data_a = Data_Otsu2018(
            np.transpose(reshape_msds(self._reflectances, self._shape).values),
            self._cmfs,
            self._sd_D65,
        )
        self._data_b = self._node_b.data

        self._partition_axis = self._node_a.partition_axis

    def test_required_attributes(self) -> None:
        """Test the presence of required attributes."""

        required_attributes = ("partition_axis", "row")

        for attribute in required_attributes:
            assert attribute in dir(Node_Otsu2018)

    def test_required_methods(self) -> None:
        """Test the presence of required methods."""

        required_methods = (
            "__init__",
            "split",
            "minimise",
            "leaf_reconstruction_error",
            "branch_reconstruction_error",
        )

        for method in required_methods:
            assert method in dir(Node_Otsu2018)

    def test_partition_axis(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Node_Otsu2018.partition_axis`
        property.
        """

        assert self._node_a.partition_axis is self._partition_axis

    def test_row(self) -> None:
        """Test :attr:`colour.recovery.otsu2018.Node_Otsu2018.row` property."""

        assert self._node_a.row == (
            self._partition_axis.origin,
            self._partition_axis.direction,
            self._node_b,
            self._node_c,
        )

    def test_raise_exception_row(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Node_Otsu2018.row` property
        raised exception.
        """

        with pytest.raises(ValueError):
            Node_Otsu2018().row  # noqa: B018

    def test_split(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Node_Otsu2018.split` method."""

        node_a = Node_Otsu2018(self._tree, None)
        node_b = Node_Otsu2018(self._tree, data=self._data_a)
        node_c = Node_Otsu2018(self._tree, data=self._data_a)
        node_a.split([node_b, node_c], PartitionAxis(12, 0))

        assert len(node_a.children) == 2

    def test_minimise(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Node_Otsu2018.minimise` method."""

        node = Node_Otsu2018(data=self._data_a)
        partition, axis, partition_error = node.minimise(3)

        assert (len(partition[0].data), len(partition[1].data)) == (10, 14)

        np.testing.assert_allclose(
            axis.origin, 0.324111380117147, atol=TOLERANCE_ABSOLUTE_TESTS
        )
        np.testing.assert_allclose(
            partition_error, 2.0402980027, atol=TOLERANCE_ABSOLUTE_TESTS
        )

    def test_leaf_reconstruction_error(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Node_Otsu2018.\
leaf_reconstruction_error` method.
        """

        np.testing.assert_allclose(
            self._node_b.leaf_reconstruction_error(),
            1.145340908277367e-29,
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

    def test_branch_reconstruction_error(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Node_Otsu2018.\
branch_reconstruction_error` method.
        """

        np.testing.assert_allclose(
            self._node_a.branch_reconstruction_error(),
            3.900015991807948e-25,
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

852 ) 

853 

854 

class TestTree_Otsu2018:
    """
    Define :class:`colour.recovery.otsu2018.Tree_Otsu2018` definition unit
    tests methods.
    """

    def setup_method(self) -> None:
        """
        Initialise the common tests attributes.

        A :class:`colour.recovery.otsu2018.Tree_Otsu2018` is built from the
        *ColorChecker N Ohta* and *BabelColor Average* reflectances. A
        temporary directory and *.npz* path are also created for the
        serialisation tests.
        """

        self._shape = SPECTRAL_SHAPE_OTSU2018
        self._cmfs, self._sd_D65 = handle_spectral_arguments(shape_default=self._shape)

        self._reflectances = sds_and_msds_to_msds(
            list(SDS_COLOURCHECKERS["ColorChecker N Ohta"].values())
            + list(SDS_COLOURCHECKERS["BabelColor Average"].values())
        )

        self._tree = Tree_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)

        self._XYZ_D65 = sd_to_XYZ(self._sd_D65)
        self._xy_D65 = XYZ_to_xy(self._XYZ_D65)

        self._temporary_directory = tempfile.mkdtemp()

        self._path = os.path.join(self._temporary_directory, "Test_Otsu2018.npz")

    def teardown_method(self) -> None:
        """Remove the temporary directory created by :meth:`setup_method`."""

        shutil.rmtree(self._temporary_directory)

    def test_required_attributes(self) -> None:
        """Test the presence of required attributes."""

        required_attributes = ("reflectances", "cmfs", "illuminant")

        for attribute in required_attributes:
            assert attribute in dir(Tree_Otsu2018)

    def test_required_methods(self) -> None:
        """Test the presence of required methods."""

        required_methods = ("__init__", "__str__", "optimise", "to_dataset")

        for method in required_methods:
            assert method in dir(Tree_Otsu2018)

    def test_reflectances(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Tree_Otsu2018.reflectances`
        property.
        """

        np.testing.assert_allclose(
            self._tree.reflectances,
            np.transpose(
                reshape_msds(
                    sds_and_msds_to_msds(self._reflectances), self._shape
                ).values
            ),
            atol=TOLERANCE_ABSOLUTE_TESTS,
        )

    def test_cmfs(self) -> None:
        """Test :attr:`colour.recovery.otsu2018.Tree_Otsu2018.cmfs` property."""

        assert self._tree.cmfs is self._cmfs

    def test_illuminant(self) -> None:
        """
        Test :attr:`colour.recovery.otsu2018.Tree_Otsu2018.illuminant`
        property.
        """

        assert self._tree.illuminant is self._sd_D65

    def test_optimise(self) -> None:
        """Test :meth:`colour.recovery.otsu2018.Tree_Otsu2018.optimise` method."""

        node_tree = Tree_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)
        node_tree.optimise(iterations=5)

        # Round-trip the optimised tree through a file so that recovery below
        # uses the dataset exactly as it would be loaded from disk.
        dataset = node_tree.to_dataset()
        dataset.write(self._path)

        dataset = Dataset_Otsu2018()
        dataset.read(self._path)

        for sd in SDS_COLOURCHECKERS["ColorChecker N Ohta"].values():
            XYZ = sd_to_XYZ(sd, self._cmfs, self._sd_D65) / 100
            Lab = XYZ_to_Lab(XYZ, self._xy_D65)

            recovered_sd = XYZ_to_sd_Otsu2018(
                XYZ, self._cmfs, self._sd_D65, dataset, clip=False
            )
            recovered_XYZ = sd_to_XYZ(recovered_sd, self._cmfs, self._sd_D65) / 100
            recovered_Lab = XYZ_to_Lab(recovered_XYZ, self._xy_D65)

            error = metric_mse(
                reshape_sd(sd, SPECTRAL_SHAPE_OTSU2018).values,
                recovered_sd.values,
            )
            assert error < 0.075

            delta_E = delta_E_CIE1976(Lab, recovered_Lab)
            assert delta_E < 1e-12

    def test_to_dataset(self) -> None:
        """
        Test :meth:`colour.recovery.otsu2018.Tree_Otsu2018.to_dataset`
        method.
        """

        node_tree = Tree_Otsu2018(self._reflectances, self._cmfs, self._sd_D65)
        dataset = node_tree.to_dataset()
        dataset.write(self._path)