diff options
Diffstat (limited to 'backend/tol_data/gen_reduced_trees.py')
| -rwxr-xr-x | backend/tol_data/gen_reduced_trees.py | 62 |
1 files changed, 50 insertions, 12 deletions
diff --git a/backend/tol_data/gen_reduced_trees.py b/backend/tol_data/gen_reduced_trees.py index 3742544..ce628f7 100755 --- a/backend/tol_data/gen_reduced_trees.py +++ b/backend/tol_data/gen_reduced_trees.py @@ -14,12 +14,14 @@ Creates reduced versions of the tree in the database: removing some more, despite any node descriptions. """ -import sys, re +import argparse +import sys +import re import sqlite3 DB_FILE = 'data.db' PICKED_NODES_FILE = 'picked_nodes.txt' -# + COMP_NAME_REGEX = re.compile(r'\[.+ \+ .+]') # Used to recognise composite nodes class Node: @@ -30,16 +32,18 @@ class Node: self.tips = tips self.pSupport = pSupport +# ========== For data generation ========== + def genData(tree: str, dbFile: str, pickedNodesFile: str) -> None: print('Opening database') dbCon = sqlite3.connect(dbFile) dbCur = dbCon.cursor() - # + print('Finding root node') query = 'SELECT name FROM nodes LEFT JOIN edges ON nodes.name = edges.child WHERE edges.parent IS NULL LIMIT 1' (rootName,) = dbCur.execute(query).fetchone() print(f'Found \'{rootName}\'') - # + print('=== Getting picked-nodes ===') pickedNames: set[str] = set() pickedTreeExists = False @@ -63,7 +67,7 @@ def genData(tree: str, dbFile: str, pickedNodesFile: str) -> None: for (name,) in dbCur.execute('SELECT name FROM nodes_p'): pickedNames.add(name) print(f'Found {len(pickedNames)} names') - # + if (tree == 'picked' or tree is None) and not pickedTreeExists: print('=== Generating picked-nodes tree ===') genPickedNodeTree(dbCur, pickedNames, rootName) @@ -88,22 +92,27 @@ def genData(tree: str, dbFile: str, pickedNodesFile: str) -> None: if tree == 'trimmed' or tree is None: print('=== Generating weakly-trimmed tree ===') genWeaklyTrimmedTree(dbCur, nodesWithImgDescOrPicked, nodesWithImgOrPicked, rootName) - # + print('Closing database') dbCon.commit() dbCon.close() + def genPickedNodeTree(dbCur: sqlite3.Cursor, pickedNames: set[str], rootName: str) -> None: PREF_NUM_CHILDREN = 3 # Include extra children up to this limit + print('Getting ancestors') nodeMap = genNodeMap(dbCur, pickedNames, 100) print(f'Result has {len(nodeMap)} nodes') + print('Removing composite nodes') removedNames = removeCompositeNodes(nodeMap) print(f'Result has {len(nodeMap)} nodes') + print('Removing \'collapsible\' nodes') temp = removeCollapsibleNodes(nodeMap, pickedNames) removedNames.update(temp) print(f'Result has {len(nodeMap)} nodes') + print('Adding some additional nearby children') namesToAdd: list[str] = [] iterNum = 0 @@ -111,7 +120,7 @@ def genPickedNodeTree(dbCur: sqlite3.Cursor, pickedNames: set[str], rootName: st iterNum += 1 if iterNum % 100 == 0: print(f'At iteration {iterNum}') - # + numChildren = len(node.children) if numChildren < PREF_NUM_CHILDREN: children = [row[0] for row in dbCur.execute('SELECT child FROM edges where parent = ?', (name,))] @@ -134,33 +143,44 @@ def genPickedNodeTree(dbCur: sqlite3.Cursor, pickedNames: set[str], rootName: st parent = None if parent == '' else parent nodeMap[name] = Node(id, [], parent, 0, pSupport == 1) print(f'Result has {len(nodeMap)} nodes') + print('Updating \'tips\' values') updateTips(rootName, nodeMap) + print('Creating table') addTreeTables(nodeMap, dbCur, 'p') + def genImagesOnlyTree( dbCur: sqlite3.Cursor, nodesWithImgOrPicked: set[str], pickedNames: set[str], rootName: str) -> None: + print('Getting ancestors') nodeMap = genNodeMap(dbCur, nodesWithImgOrPicked, 1e4) print(f'Result has {len(nodeMap)} nodes') + print('Removing composite nodes') removeCompositeNodes(nodeMap) print(f'Result has {len(nodeMap)} nodes') + print('Removing \'collapsible\' nodes') removeCollapsibleNodes(nodeMap, pickedNames) print(f'Result has {len(nodeMap)} nodes') + print('Updating \'tips\' values') # Needed for next trimming step updateTips(rootName, nodeMap) + print('Trimming from nodes with \'many\' children') trimIfManyChildren(nodeMap, rootName, 300, pickedNames) print(f'Result has {len(nodeMap)} nodes') + print('Updating \'tips\' values') updateTips(rootName, nodeMap) + print('Creating table') addTreeTables(nodeMap, dbCur, 'i') + def genWeaklyTrimmedTree( dbCur: sqlite3.Cursor, nodesWithImgDescOrPicked: set[str], @@ -169,6 +189,7 @@ def genWeaklyTrimmedTree( print('Getting ancestors') nodeMap = genNodeMap(dbCur, nodesWithImgDescOrPicked, 1e5) print(f'Result has {len(nodeMap)} nodes') + print('Getting nodes to \'strongly keep\'') iterNum = 0 nodesFromImgOrPicked: set[str] = set() @@ -184,19 +205,26 @@ def genWeaklyTrimmedTree( else: break print(f'Node set has {len(nodesFromImgOrPicked)} nodes') + print('Removing \'collapsible\' nodes') removeCollapsibleNodes(nodeMap, nodesWithImgDescOrPicked) print(f'Result has {len(nodeMap)} nodes') + print('Updating \'tips\' values') # Needed for next trimming step updateTips(rootName, nodeMap) + print('Trimming from nodes with \'many\' children') trimIfManyChildren(nodeMap, rootName, 600, nodesFromImgOrPicked) print(f'Result has {len(nodeMap)} nodes') + print('Updating \'tips\' values') updateTips(rootName, nodeMap) + print('Creating table') addTreeTables(nodeMap, dbCur, 't') -# Helper functions + +# ========== Helper functions ========== + def genNodeMap(dbCur: sqlite3.Cursor, nameSet: set[str], itersBeforePrint = 1) -> dict[str, Node]: """ Returns a subtree that includes nodes in 'nameSet', as a name-to-Node map """ nodeMap: dict[str, Node] = {} @@ -206,7 +234,7 @@ def genNodeMap(dbCur: sqlite3.Cursor, nameSet: set[str], itersBeforePrint = 1) - iterNum += 1 if iterNum % itersBeforePrint == 0: print(f'At iteration {iterNum}') - # + prevName: str | None = None while name is not None: if name not in nodeMap: @@ -227,6 +255,7 @@ def genNodeMap(dbCur: sqlite3.Cursor, nameSet: set[str], itersBeforePrint = 1) - nodeMap[name].children.append(prevName) break return nodeMap + def removeCompositeNodes(nodeMap: dict[str, Node]) -> set[str]: """ Given a tree, removes composite-name nodes, and returns the removed nodes' names """ namesToRemove: set[str] = set() @@ -244,10 +273,12 @@ def removeCompositeNodes(nodeMap: dict[str, Node]) -> set[str]: for name in namesToRemove: del nodeMap[name] return namesToRemove + def removeCollapsibleNodes(nodeMap: dict[str, Node], nodesToKeep: set[str] = set()) -> set[str]: """ Given a tree, removes single-child parents, then only-childs, with given exceptions, and returns the set of removed nodes' names """ namesToRemove: set[str] = set() + # Remove single-child parents for name, node in nodeMap.items(): if len(node.children) == 1 and node.parent is not None and name not in nodesToKeep: @@ -262,6 +293,7 @@ def removeCollapsibleNodes(nodeMap: dict[str, Node], nodesToKeep: set[str] = set namesToRemove.add(name) for name in namesToRemove: del nodeMap[name] + # Remove only-childs (not redundant because 'nodesToKeep' can cause single-child parents to be kept) namesToRemove.clear() for name, node in nodeMap.items(): @@ -277,8 +309,9 @@ def removeCollapsibleNodes(nodeMap: dict[str, Node], nodesToKeep: set[str] = set namesToRemove.add(name) for name in namesToRemove: del nodeMap[name] - # + return namesToRemove + def trimIfManyChildren( nodeMap: dict[str, Node], rootName: str, childThreshold: int, nodesToKeep: set[str] = set()) -> None: namesToRemove: set[str] = set() @@ -299,14 +332,17 @@ def trimIfManyChildren( # Recurse on children for n in node.children: findTrimmables(n) + def markForRemoval(nodeName: str) -> None: nonlocal nodeMap, namesToRemove namesToRemove.add(nodeName) for child in nodeMap[nodeName].children: markForRemoval(child) + findTrimmables(rootName) for nodeName in namesToRemove: del nodeMap[nodeName] + def updateTips(nodeName: str, nodeMap: dict[str, Node]) -> int: """ Updates the 'tips' values for a node and it's descendants, returning the node's new 'tips' value """ node = nodeMap[nodeName] @@ -314,6 +350,7 @@ def updateTips(nodeName: str, nodeMap: dict[str, Node]) -> int: tips = max(1, tips) node.tips = tips return tips + def addTreeTables(nodeMap: dict[str, Node], dbCur: sqlite3.Cursor, suffix: str): """ Adds a tree to the database, as tables nodes_X and edges_X, where X is the given suffix """ nodesTbl = f'nodes_{suffix}' @@ -328,10 +365,11 @@ def addTreeTables(nodeMap: dict[str, Node], dbCur: sqlite3.Cursor, suffix: str): pSupport = 1 if nodeMap[childName].pSupport else 0 dbCur.execute(f'INSERT INTO {edgesTbl} VALUES (?, ?, ?)', (name, childName, pSupport)) +# ========== Main block ========== + if __name__ == '__main__': - import argparse parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument('--tree', choices=['picked', 'images', 'trimmed'], help='Only generate the specified tree') args = parser.parse_args() - # + genData(args.tree, DB_FILE, PICKED_NODES_FILE) |
