diff --git a/medcat-v2/medcat/utils/cdb_utils.py b/medcat-v2/medcat/utils/cdb_utils.py new file mode 100644 index 00000000..409de5ae --- /dev/null +++ b/medcat-v2/medcat/utils/cdb_utils.py @@ -0,0 +1,292 @@ +from collections import defaultdict +import logging +import numpy as np + +from copy import deepcopy +from medcat.cdb import CDB + +logger = logging.getLogger(__name__) # separate logger from the package-level one + + +def merge_cdb(cdb1: CDB, + cdb2: CDB, + overwrite_training: int = 0, + full_build: bool = False) -> CDB: + """Merge two CDB's together to produce a new, single CDB. The contents of + inputs CDBs will not be changed. + `addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build` + + Args: + cdb1 (CDB): + The first medcat cdb to merge. In cases where merging isn't suitable + isn't ideal (such as cui2preferred_name), this cdb values will be + prioritised over cdb2. + cdb2 (CDB): + The second medcat cdb to merge. + overwrite_training (int): + Choose to prioritise a CDB's context vectors values over merging gracefully. + 0 - no prio, 1 - CDB1, 2 - CDB2 + full_build (bool): + Add additional information from "addl_info" dicts "cui2ontologies" and + "cui2description" + + Returns: + CDB: The merged CDB. 
+ """ + config = deepcopy(cdb1.config) + cdb = CDB(config) + + # Copy CDB 1 - as all settings from CDB 1 will be carried over + cdb.cui2info = deepcopy(cdb1.cui2info) + cdb.name2info = deepcopy(cdb1.name2info) + cdb.type_id2info = deepcopy(cdb1.type_id2info) + cdb.token_counts = deepcopy(cdb1.token_counts) + cdb._subnames = deepcopy(cdb1._subnames) + if full_build: + cdb.addl_info = deepcopy(cdb1.addl_info) + + # Merge concepts from cdb2 into the merged CDB + for cui, cui_info2 in cdb2.cui2info.items(): + # Get name status from cdb2 + name_status = 'A' # default status + for name in cui_info2['names']: + if name in cdb2.name2info: + name_info = cdb2.name2info[name] + if cui in name_info['per_cui_status']: + name_status = name_info['per_cui_status'][cui] + break + + # Prepare names dict for _add_concept + names = {} + for name in cui_info2['names']: + # Create a simple NameDescriptor-like structure + names[name] = type('NameDescriptor', (), { + 'snames': cui_info2['subnames'], + 'is_upper': cdb2.name2info.get(name, {}).get('is_upper', False), + 'tokens': set(), # We don't have token info in the new structure + 'raw_name': name + })() + + # Get ontologies and description for full_build + ontologies = set() + description = cui_info2.get('description', '') + to_build = full_build and ( + cui_info2.get('original_names') is not None or + cui_info2.get('description') is not None + ) + + if to_build and cui_info2.get('in_other_ontology'): + ontologies.update(cui_info2['in_other_ontology']) + + cdb._add_concept( + cui=cui, names=names, ontologies=ontologies, name_status=name_status, + type_ids=cui_info2['type_ids'], description=description, + full_build=to_build + ) + + # Copy training data from cdb2 for concepts that don't exist in cdb1 + if cui not in cdb1.cui2info: + cui_info_merged = cdb.cui2info[cui] + cui_info_merged['count_train'] = cui_info2['count_train'] + cui_info_merged['context_vectors'] = deepcopy(cui_info2['context_vectors']) + 
cui_info_merged['average_confidence'] = cui_info2['average_confidence'] + if cui_info2.get('tags'): + cui_info_merged['tags'] = deepcopy(cui_info2['tags']) + + # Handle merging of training data for concepts that exist in both CDBs + if cui in cdb1.cui2info: + cui_info1 = cdb1.cui2info[cui] + cui_info_merged = cdb.cui2info[cui] + + # Merge count_train + if (cui_info1['count_train'] > 0 or cui_info2['count_train'] > 0) and not ( + overwrite_training == 1 and cui_info1['count_train'] > 0 + ): + if overwrite_training == 2 and cui_info2['count_train'] > 0: + cui_info_merged['count_train'] = cui_info2['count_train'] + else: + cui_info_merged['count_train'] = ( + cui_info1['count_train'] + cui_info2['count_train'] + ) + + # Merge context vectors + if (cui_info1['context_vectors'] is not None and + not (overwrite_training == 1 and + cui_info1['context_vectors'] is not None)): + + if (overwrite_training == 2 and + cui_info2['context_vectors'] is not None): + cui_info_merged['context_vectors'] = deepcopy( + cui_info2['context_vectors'] + ) + else: + # Merge context vectors with weighted average + if cui_info_merged['context_vectors'] is None: + cui_info_merged['context_vectors'] = {} + + # Get all context types from both CDBs + contexts = set() + if cui_info1['context_vectors']: + contexts.update(cui_info1['context_vectors'].keys()) + if cui_info2['context_vectors']: + contexts.update(cui_info2['context_vectors'].keys()) + + # Calculate weights + if overwrite_training == 2: + weights = [0, 1] + else: + norm = cui_info_merged['count_train'] + if norm > 0: + weights = [ + np.divide(cui_info1['count_train'], norm), + np.divide(cui_info2['count_train'], norm) + ] + else: + weights = [0.5, 0.5] # equal weights if no training + + # Merge each context vector + for context_type in contexts: + if cui_info1['context_vectors']: + vec1 = cui_info1['context_vectors'].get( + context_type, np.zeros(300) + ) + else: + vec1 = np.zeros(300) + + if cui_info2['context_vectors']: + vec2 = 
cui_info2['context_vectors'].get( + context_type, np.zeros(300) + ) + else: + vec2 = np.zeros(300) + cui_info_merged['context_vectors'][context_type] = ( + weights[0] * vec1 + weights[1] * vec2 + ) + + # Merge tags + if cui_info1.get('tags') and cui_info2.get('tags'): + if cui_info_merged['tags'] is None: + cui_info_merged['tags'] = [] + cui_info_merged['tags'].extend(cui_info2['tags']) + + # Merge type_ids (already handled by _add_concept, but ensure union) + cui_info_merged['type_ids'].update(cui_info2['type_ids']) + + # Merge name training counts + if overwrite_training != 1: + for name, name_info2 in cdb2.name2info.items(): + if name in cdb1.name2info and overwrite_training == 0: + # Merge training counts for names that exist in both CDBs + name_info1 = cdb1.name2info[name] + name_info_merged = cdb.name2info[name] + name_info_merged['count_train'] = ( + name_info1['count_train'] + name_info2['count_train'] + ) + else: + # Copy name info from cdb2 if it doesn't exist in cdb1 + if name not in cdb.name2info: + cdb.name2info[name] = deepcopy(name_info2) + + # Merge token counts + if overwrite_training != 1: + for token, count in cdb2.token_counts.items(): + if token in cdb.token_counts and overwrite_training == 0: + cdb.token_counts[token] += count + else: + cdb.token_counts[token] = count + + return cdb + + +def _dedupe_preserve_order(items: list[str]) -> list[str]: + seen = set() + deduped_list = [] + for item in items: + if item not in seen: + seen.add(item) + deduped_list.append(item) + return deduped_list + + +def get_all_ch(parent_cui: str, cdb): + """Get all the children of a given parent CUI. 
Preserves the depth-first traversal order, starting from the parent CUI itself
+ + Returns: + dict: The concept path and links + """ + try: + def find_parents(cui, cuis2nodes, child_node=None): + parents = list(cdb.addl_info.get('ch2pt', {}).get(cui, [])) + all_links = [] + if cui not in cuis2nodes: + # Get preferred name from the new CDB structure + preferred_name = cdb.cui2info.get(cui, {}).get('preferred_name', cui) + curr_node = {'cui': cui, 'pretty_name': preferred_name} + if child_node: + curr_node['children'] = [child_node] + cuis2nodes[cui] = curr_node + if len(parents) > 0: + all_links += find_parents( + parents[0], cuis2nodes, child_node=curr_node + ) + for p in parents[1:]: + links = find_parents(p, cuis2nodes) + all_links += [{'parent': p, 'child': cui}] + links + else: + if child_node: + if 'children' not in cuis2nodes[cui]: + cuis2nodes[cui]['children'] = [] + cuis2nodes[cui]['children'].append(child_node) + return all_links + cuis2nodes = dict() + all_links = find_parents(cui, cuis2nodes) + return { + 'node_path': cuis2nodes[parent_node], + 'links': all_links + } + except KeyError as e: + logger.warning(f'Cannot find path concept path:{e}') + return [] diff --git a/medcat-v2/tests/utils/test_cdb_utils.py b/medcat-v2/tests/utils/test_cdb_utils.py new file mode 100644 index 00000000..d341c71c --- /dev/null +++ b/medcat-v2/tests/utils/test_cdb_utils.py @@ -0,0 +1,454 @@ +import unittest +import numpy as np + +from medcat.cdb import CDB +from medcat.config import Config +from medcat.cdb.concepts import get_new_cui_info, get_new_name_info +from medcat.utils.cdb_utils import ( + merge_cdb, _dedupe_preserve_order, get_all_ch, + ch2pt_from_pt2ch, snomed_ct_concept_path +) + + +class CDBUtilsTests(unittest.TestCase): + """Test cases for medcat.utils.cdb_utils module.""" + + @classmethod + def setUpClass(cls): + """Set up test fixtures.""" + cls.config = Config() + cls.config.general.log_level = 20 # INFO level + + def setUp(self): + """Set up for each test.""" + self.cdb1 = self._create_test_cdb("cdb1") + self.cdb2 = 
self._create_test_cdb("cdb2") + + def _create_test_cdb(self, name: str) -> CDB: + """Create a test CDB with sample data.""" + cdb = CDB(self.config) + + if name == "cdb1": + # CUI1 with training data + cui_info1 = get_new_cui_info( + cui="CUI1", + preferred_name="Test Concept 1", + names={"test concept 1", "tc1"}, + subnames={"test", "concept", "tc1"}, + type_ids={"T001"}, + description="First test concept", + count_train=10, + context_vectors={ + "long": np.random.rand(300), + "short": np.random.rand(300) + }, + average_confidence=0.8 + ) + cdb.cui2info["CUI1"] = cui_info1 + + # CUI2 without training data + cui_info2 = get_new_cui_info( + cui="CUI2", + preferred_name="Test Concept 2", + names={"test concept 2", "tc2"}, + subnames={"test", "concept", "tc2"}, + type_ids={"T002"}, + description="Second test concept", + count_train=0, + context_vectors=None, + average_confidence=0.0 + ) + cdb.cui2info["CUI2"] = cui_info2 + + # Add name info + name_info1 = get_new_name_info( + name="test concept 1", + per_cui_status={"CUI1": "A"}, + is_upper=False, + count_train=5 + ) + cdb.name2info["test concept 1"] = name_info1 + + elif name == "cdb2": + # CUI1 with different training data (should be merged) + cui_info1 = get_new_cui_info( + cui="CUI1", + preferred_name="Test Concept 1", + names={"test concept 1", "tc1", "concept one"}, + subnames={"test", "concept", "tc1", "one"}, + type_ids={"T001", "T003"}, + description="First test concept (updated)", + count_train=15, + context_vectors={ + "long": np.random.rand(300), + "medium": np.random.rand(300) + }, + average_confidence=0.9 + ) + cdb.cui2info["CUI1"] = cui_info1 + + # CUI3 (new concept) + cui_info3 = get_new_cui_info( + cui="CUI3", + preferred_name="Test Concept 3", + names={"test concept 3", "tc3"}, + subnames={"test", "concept", "tc3"}, + type_ids={"T003"}, + description="Third test concept", + count_train=5, + context_vectors={"short": np.random.rand(300)}, + average_confidence=0.7 + ) + cdb.cui2info["CUI3"] = cui_info3 
+ + # Add name info + name_info1 = get_new_name_info( + name="test concept 1", + per_cui_status={"CUI1": "P"}, + is_upper=False, + count_train=8 + ) + cdb.name2info["test concept 1"] = name_info1 + + # Add token counts + cdb.token_counts["test"] = 10 + cdb.token_counts["concept"] = 15 + + return cdb + + def test_merge_cdb_basic_merge(self): + """Test basic CDB merging functionality.""" + merged_cdb = merge_cdb( + self.cdb1, self.cdb2, overwrite_training=0, full_build=True + ) + + # Should have 3 concepts total + self.assertEqual(len(merged_cdb.cui2info), 3) + + # CUI1 should be merged + self.assertIn("CUI1", merged_cdb.cui2info) + cui_info = merged_cdb.cui2info["CUI1"] + self.assertEqual(cui_info['count_train'], 25) # 10 + 15 + self.assertIn("concept one", cui_info['names']) + self.assertIn("T003", cui_info['type_ids']) + self.assertIsNotNone(cui_info['context_vectors']) + self.assertIn("medium", cui_info['context_vectors']) + + # CUI3 should be added + self.assertIn("CUI3", merged_cdb.cui2info) + cui_info3 = merged_cdb.cui2info["CUI3"] + self.assertEqual(cui_info3['count_train'], 5) + self.assertIsNotNone(cui_info3['context_vectors']) + + # CUI2 should be preserved + self.assertIn("CUI2", merged_cdb.cui2info) + cui_info2 = merged_cdb.cui2info["CUI2"] + self.assertEqual(cui_info2['count_train'], 0) + + def test_merge_cdb_overwrite_training_cdb1(self): + """Test CDB merging with overwrite_training=1 (prioritize cdb1).""" + merged_cdb = merge_cdb( + self.cdb1, self.cdb2, overwrite_training=1, full_build=True + ) + + # CUI1 should keep cdb1's training data + cui_info = merged_cdb.cui2info["CUI1"] + self.assertEqual(cui_info['count_train'], 10) # Only cdb1's count + self.assertIn("long", cui_info['context_vectors']) + self.assertIn("short", cui_info['context_vectors']) + # Should not have medium from cdb2 + self.assertNotIn("medium", cui_info['context_vectors']) + + def test_merge_cdb_overwrite_training_cdb2(self): + """Test CDB merging with overwrite_training=2 
(prioritize cdb2).""" + merged_cdb = merge_cdb( + self.cdb1, self.cdb2, overwrite_training=2, full_build=True + ) + + # CUI1 should use cdb2's training data + cui_info = merged_cdb.cui2info["CUI1"] + self.assertEqual(cui_info['count_train'], 15) # Only cdb2's count + self.assertIn("long", cui_info['context_vectors']) + self.assertIn("medium", cui_info['context_vectors']) + # Should not have short from cdb1 + self.assertNotIn("short", cui_info['context_vectors']) + + def test_merge_cdb_name_info_merging(self): + """Test that name information is properly merged.""" + merged_cdb = merge_cdb( + self.cdb1, self.cdb2, overwrite_training=0, full_build=True + ) + + # Name info should be merged + name_info = merged_cdb.name2info["test concept 1"] + self.assertEqual(name_info['count_train'], 13) # 5 + 8 + # Should use cdb2's status (P) since it's more recent + self.assertEqual(name_info['per_cui_status']['CUI1'], 'P') + + def test_merge_cdb_token_counts(self): + """Test that token counts are properly merged.""" + merged_cdb = merge_cdb( + self.cdb1, self.cdb2, overwrite_training=0, full_build=True + ) + + # Token counts should be merged + self.assertEqual(merged_cdb.token_counts['test'], 20) # 10 + 10 + self.assertEqual(merged_cdb.token_counts['concept'], 30) # 15 + 15 + + def test_merge_cdb_preserves_original_cdbs(self): + """Test that original CDBs are not modified.""" + original_cdb1_cui1_count = self.cdb1.cui2info["CUI1"]['count_train'] + original_cdb2_cui1_count = self.cdb2.cui2info["CUI1"]['count_train'] + + merged_cdb = merge_cdb( + self.cdb1, self.cdb2, overwrite_training=0, full_build=True + ) + + # Original CDBs should be unchanged + self.assertEqual( + self.cdb1.cui2info["CUI1"]['count_train'], original_cdb1_cui1_count + ) + self.assertEqual( + self.cdb2.cui2info["CUI1"]['count_train'], original_cdb2_cui1_count + ) + + def test_merge_cdb_empty_cdb2(self): + """Test merging with an empty cdb2.""" + empty_cdb = CDB(self.config) + merged_cdb = merge_cdb( + self.cdb1, 
empty_cdb, overwrite_training=0, full_build=True + ) + + # Should be identical to cdb1 + self.assertEqual(len(merged_cdb.cui2info), len(self.cdb1.cui2info)) + self.assertEqual(merged_cdb.cui2info["CUI1"]['count_train'], 10) + + def test_merge_cdb_empty_cdb1(self): + """Test merging with an empty cdb1.""" + empty_cdb = CDB(self.config) + merged_cdb = merge_cdb( + empty_cdb, self.cdb2, overwrite_training=0, full_build=True + ) + + # Should be identical to cdb2 + self.assertEqual(len(merged_cdb.cui2info), len(self.cdb2.cui2info)) + self.assertEqual(merged_cdb.cui2info["CUI1"]['count_train'], 15) + + def test_dedupe_preserve_order(self): + """Test the _dedupe_preserve_order function.""" + # Test with duplicates + items = ["a", "b", "a", "c", "b", "d"] + result = _dedupe_preserve_order(items) + expected = ["a", "b", "c", "d"] + self.assertEqual(result, expected) + + # Test with no duplicates + items = ["a", "b", "c", "d"] + result = _dedupe_preserve_order(items) + self.assertEqual(result, items) + + # Test with empty list + result = _dedupe_preserve_order([]) + self.assertEqual(result, []) + + # Test with single item + result = _dedupe_preserve_order(["a"]) + self.assertEqual(result, ["a"]) + + def test_get_all_ch(self): + """Test the get_all_ch function.""" + # Create a CDB with parent-child relationships + cdb = CDB(self.config) + cdb.addl_info = { + 'pt2ch': { + 'parent1': ['child1', 'child2'], + 'child1': ['grandchild1'], + 'child2': ['grandchild2'], + 'grandchild1': [] + } + } + + # Test getting all children of parent1 + all_children = get_all_ch('parent1', cdb) + expected = ['parent1', 'child1', 'grandchild1', 'child2', 'grandchild2'] + self.assertEqual(all_children, expected) + + # Test getting children of a leaf node + all_children = get_all_ch('grandchild1', cdb) + self.assertEqual(all_children, ['grandchild1']) + + # Test getting children of a node with no children + all_children = get_all_ch('nonexistent', cdb) + self.assertEqual(all_children, 
['nonexistent']) + + def test_ch2pt_from_pt2ch(self): + """Test the ch2pt_from_pt2ch function.""" + # Create a CDB with parent-child relationships + cdb = CDB(self.config) + cdb.addl_info = { + 'pt2ch': { + 'parent1': ['child1', 'child2'], + 'parent2': ['child1', 'child3'], + 'child1': ['grandchild1'] + } + } + + # Test default key + ch2pt = ch2pt_from_pt2ch(cdb) + expected = { + 'child1': ['parent1', 'parent2'], + 'child2': ['parent1'], + 'child3': ['parent2'], + 'grandchild1': ['child1'] + } + self.assertEqual(ch2pt, expected) + + # Test with custom key + cdb.addl_info['custom_pt2ch'] = { + 'root': ['branch1', 'branch2'] + } + ch2pt = ch2pt_from_pt2ch(cdb, 'custom_pt2ch') + expected = { + 'branch1': ['root'], + 'branch2': ['root'] + } + self.assertEqual(ch2pt, expected) + + def test_snomed_ct_concept_path(self): + """Test the snomed_ct_concept_path function.""" + # Create a CDB with SNOMED CT hierarchy + cdb = CDB(self.config) + cdb.addl_info = { + 'ch2pt': { + '138875005': [], # Root + '123456789': ['138875005'], # Child of root + '987654321': ['123456789'], # Grandchild + '111222333': ['987654321'] # Great-grandchild + } + } + + # Add CUI info for the concepts + cdb.cui2info = { + '138875005': get_new_cui_info( + cui='138875005', preferred_name='SNOMED CT Root' + ), + '123456789': get_new_cui_info( + cui='123456789', preferred_name='Clinical Finding' + ), + '987654321': get_new_cui_info(cui='987654321', preferred_name='Disease'), + '111222333': get_new_cui_info(cui='111222333', preferred_name='Diabetes') + } + + # Test getting path for a concept + result = snomed_ct_concept_path('111222333', cdb) + + # Should return a dict with node_path and links + self.assertIn('node_path', result) + self.assertIn('links', result) + + # The node_path should contain the root node structure + self.assertEqual(result['node_path']['cui'], '138875005') + self.assertEqual(result['node_path']['pretty_name'], 'SNOMED CT Root') + + # Test with non-existent concept + result = 
snomed_ct_concept_path('nonexistent', cdb) + self.assertEqual(result, []) + + def test_snomed_ct_concept_path_custom_parent(self): + """Test snomed_ct_concept_path with custom parent node.""" + # Create a CDB with hierarchy + cdb = CDB(self.config) + cdb.addl_info = { + 'ch2pt': { + 'root': [], + 'parent': ['root'], + 'child': ['parent'] + } + } + + # Add CUI info + cdb.cui2info = { + 'root': get_new_cui_info(cui='root', preferred_name='Root'), + 'parent': get_new_cui_info(cui='parent', preferred_name='Parent'), + 'child': get_new_cui_info(cui='child', preferred_name='Child') + } + + # Test with custom parent + result = snomed_ct_concept_path('child', cdb, parent_node='parent') + + # Should return a dict with node_path and links + self.assertIn('node_path', result) + self.assertIn('links', result) + + # The node_path should contain the custom parent node structure + self.assertEqual(result['node_path']['cui'], 'parent') + self.assertEqual(result['node_path']['pretty_name'], 'Parent') + + def test_merge_cdb_context_vector_weights(self): + """Test that context vectors are properly weighted during merging.""" + # Create CDBs with known context vectors for testing + cdb1 = CDB(self.config) + cdb2 = CDB(self.config) + + # Create known vectors + vec1_long = np.array([1.0, 2.0, 3.0] + [0.0] * 297) # 300 dimensions + vec2_long = np.array([4.0, 5.0, 6.0] + [0.0] * 297) + + cui_info1 = get_new_cui_info( + cui="CUI1", + preferred_name="Test", + names={"test"}, + count_train=10, + context_vectors={"long": vec1_long} + ) + cdb1.cui2info["CUI1"] = cui_info1 + + cui_info2 = get_new_cui_info( + cui="CUI1", + preferred_name="Test", + names={"test"}, + count_train=20, + context_vectors={"long": vec2_long} + ) + cdb2.cui2info["CUI1"] = cui_info2 + + # Merge with equal weights (overwrite_training=0) + merged_cdb = merge_cdb(cdb1, cdb2, overwrite_training=0, full_build=True) + + # Check that the merged vector is properly weighted + merged_vec = 
merged_cdb.cui2info["CUI1"]['context_vectors']['long'] + expected_vec = (10/30) * vec1_long + (20/30) * vec2_long + + np.testing.assert_array_almost_equal(merged_vec, expected_vec, decimal=10) + + def test_merge_cdb_tags_merging(self): + """Test that tags are properly merged.""" + cdb1 = CDB(self.config) + cdb2 = CDB(self.config) + + cui_info1 = get_new_cui_info( + cui="CUI1", + preferred_name="Test", + names={"test"}, + tags=["tag1", "tag2"] + ) + cdb1.cui2info["CUI1"] = cui_info1 + + cui_info2 = get_new_cui_info( + cui="CUI1", + preferred_name="Test", + names={"test"}, + tags=["tag3", "tag4"] + ) + cdb2.cui2info["CUI1"] = cui_info2 + + merged_cdb = merge_cdb(cdb1, cdb2, overwrite_training=0, full_build=True) + + # Tags should be merged + merged_tags = merged_cdb.cui2info["CUI1"]['tags'] + expected_tags = ["tag1", "tag2", "tag3", "tag4"] + self.assertEqual(merged_tags, expected_tags) + + +if __name__ == '__main__': + unittest.main()