aboutsummaryrefslogtreecommitdiff
path: root/backend/tol_data/gen_reduced_trees.py
diff options
context:
space:
mode:
Diffstat (limited to 'backend/tol_data/gen_reduced_trees.py')
-rwxr-xr-xbackend/tol_data/gen_reduced_trees.py62
1 files changed, 50 insertions, 12 deletions
diff --git a/backend/tol_data/gen_reduced_trees.py b/backend/tol_data/gen_reduced_trees.py
index 3742544..ce628f7 100755
--- a/backend/tol_data/gen_reduced_trees.py
+++ b/backend/tol_data/gen_reduced_trees.py
@@ -14,12 +14,14 @@ Creates reduced versions of the tree in the database:
removing some more, despite any node descriptions.
"""
-import sys, re
+import argparse
+import sys
+import re
import sqlite3
DB_FILE = 'data.db'
PICKED_NODES_FILE = 'picked_nodes.txt'
-#
+
COMP_NAME_REGEX = re.compile(r'\[.+ \+ .+]') # Used to recognise composite nodes
class Node:
@@ -30,16 +32,18 @@ class Node:
self.tips = tips
self.pSupport = pSupport
+# ========== For data generation ==========
+
def genData(tree: str, dbFile: str, pickedNodesFile: str) -> None:
print('Opening database')
dbCon = sqlite3.connect(dbFile)
dbCur = dbCon.cursor()
- #
+
print('Finding root node')
query = 'SELECT name FROM nodes LEFT JOIN edges ON nodes.name = edges.child WHERE edges.parent IS NULL LIMIT 1'
(rootName,) = dbCur.execute(query).fetchone()
print(f'Found \'{rootName}\'')
- #
+
print('=== Getting picked-nodes ===')
pickedNames: set[str] = set()
pickedTreeExists = False
@@ -63,7 +67,7 @@ def genData(tree: str, dbFile: str, pickedNodesFile: str) -> None:
for (name,) in dbCur.execute('SELECT name FROM nodes_p'):
pickedNames.add(name)
print(f'Found {len(pickedNames)} names')
- #
+
if (tree == 'picked' or tree is None) and not pickedTreeExists:
print('=== Generating picked-nodes tree ===')
genPickedNodeTree(dbCur, pickedNames, rootName)
@@ -88,22 +92,27 @@ def genData(tree: str, dbFile: str, pickedNodesFile: str) -> None:
if tree == 'trimmed' or tree is None:
print('=== Generating weakly-trimmed tree ===')
genWeaklyTrimmedTree(dbCur, nodesWithImgDescOrPicked, nodesWithImgOrPicked, rootName)
- #
+
print('Closing database')
dbCon.commit()
dbCon.close()
+
def genPickedNodeTree(dbCur: sqlite3.Cursor, pickedNames: set[str], rootName: str) -> None:
PREF_NUM_CHILDREN = 3 # Include extra children up to this limit
+
print('Getting ancestors')
nodeMap = genNodeMap(dbCur, pickedNames, 100)
print(f'Result has {len(nodeMap)} nodes')
+
print('Removing composite nodes')
removedNames = removeCompositeNodes(nodeMap)
print(f'Result has {len(nodeMap)} nodes')
+
print('Removing \'collapsible\' nodes')
temp = removeCollapsibleNodes(nodeMap, pickedNames)
removedNames.update(temp)
print(f'Result has {len(nodeMap)} nodes')
+
print('Adding some additional nearby children')
namesToAdd: list[str] = []
iterNum = 0
@@ -111,7 +120,7 @@ def genPickedNodeTree(dbCur: sqlite3.Cursor, pickedNames: set[str], rootName: st
iterNum += 1
if iterNum % 100 == 0:
print(f'At iteration {iterNum}')
- #
+
numChildren = len(node.children)
if numChildren < PREF_NUM_CHILDREN:
children = [row[0] for row in dbCur.execute('SELECT child FROM edges where parent = ?', (name,))]
@@ -134,33 +143,44 @@ def genPickedNodeTree(dbCur: sqlite3.Cursor, pickedNames: set[str], rootName: st
parent = None if parent == '' else parent
nodeMap[name] = Node(id, [], parent, 0, pSupport == 1)
print(f'Result has {len(nodeMap)} nodes')
+
print('Updating \'tips\' values')
updateTips(rootName, nodeMap)
+
print('Creating table')
addTreeTables(nodeMap, dbCur, 'p')
+
def genImagesOnlyTree(
dbCur: sqlite3.Cursor,
nodesWithImgOrPicked: set[str],
pickedNames: set[str],
rootName: str) -> None:
+
print('Getting ancestors')
nodeMap = genNodeMap(dbCur, nodesWithImgOrPicked, 1e4)
print(f'Result has {len(nodeMap)} nodes')
+
print('Removing composite nodes')
removeCompositeNodes(nodeMap)
print(f'Result has {len(nodeMap)} nodes')
+
print('Removing \'collapsible\' nodes')
removeCollapsibleNodes(nodeMap, pickedNames)
print(f'Result has {len(nodeMap)} nodes')
+
print('Updating \'tips\' values') # Needed for next trimming step
updateTips(rootName, nodeMap)
+
print('Trimming from nodes with \'many\' children')
trimIfManyChildren(nodeMap, rootName, 300, pickedNames)
print(f'Result has {len(nodeMap)} nodes')
+
print('Updating \'tips\' values')
updateTips(rootName, nodeMap)
+
print('Creating table')
addTreeTables(nodeMap, dbCur, 'i')
+
def genWeaklyTrimmedTree(
dbCur: sqlite3.Cursor,
nodesWithImgDescOrPicked: set[str],
@@ -169,6 +189,7 @@ def genWeaklyTrimmedTree(
print('Getting ancestors')
nodeMap = genNodeMap(dbCur, nodesWithImgDescOrPicked, 1e5)
print(f'Result has {len(nodeMap)} nodes')
+
print('Getting nodes to \'strongly keep\'')
iterNum = 0
nodesFromImgOrPicked: set[str] = set()
@@ -184,19 +205,26 @@ def genWeaklyTrimmedTree(
else:
break
print(f'Node set has {len(nodesFromImgOrPicked)} nodes')
+
print('Removing \'collapsible\' nodes')
removeCollapsibleNodes(nodeMap, nodesWithImgDescOrPicked)
print(f'Result has {len(nodeMap)} nodes')
+
print('Updating \'tips\' values') # Needed for next trimming step
updateTips(rootName, nodeMap)
+
print('Trimming from nodes with \'many\' children')
trimIfManyChildren(nodeMap, rootName, 600, nodesFromImgOrPicked)
print(f'Result has {len(nodeMap)} nodes')
+
print('Updating \'tips\' values')
updateTips(rootName, nodeMap)
+
print('Creating table')
addTreeTables(nodeMap, dbCur, 't')
-# Helper functions
+
+# ========== Helper functions ==========
+
def genNodeMap(dbCur: sqlite3.Cursor, nameSet: set[str], itersBeforePrint = 1) -> dict[str, Node]:
""" Returns a subtree that includes nodes in 'nameSet', as a name-to-Node map """
nodeMap: dict[str, Node] = {}
@@ -206,7 +234,7 @@ def genNodeMap(dbCur: sqlite3.Cursor, nameSet: set[str], itersBeforePrint = 1) -
iterNum += 1
if iterNum % itersBeforePrint == 0:
print(f'At iteration {iterNum}')
- #
+
prevName: str | None = None
while name is not None:
if name not in nodeMap:
@@ -227,6 +255,7 @@ def genNodeMap(dbCur: sqlite3.Cursor, nameSet: set[str], itersBeforePrint = 1) -
nodeMap[name].children.append(prevName)
break
return nodeMap
+
def removeCompositeNodes(nodeMap: dict[str, Node]) -> set[str]:
""" Given a tree, removes composite-name nodes, and returns the removed nodes' names """
namesToRemove: set[str] = set()
@@ -244,10 +273,12 @@ def removeCompositeNodes(nodeMap: dict[str, Node]) -> set[str]:
for name in namesToRemove:
del nodeMap[name]
return namesToRemove
+
def removeCollapsibleNodes(nodeMap: dict[str, Node], nodesToKeep: set[str] = set()) -> set[str]:
""" Given a tree, removes single-child parents, then only-childs,
with given exceptions, and returns the set of removed nodes' names """
namesToRemove: set[str] = set()
+
# Remove single-child parents
for name, node in nodeMap.items():
if len(node.children) == 1 and node.parent is not None and name not in nodesToKeep:
@@ -262,6 +293,7 @@ def removeCollapsibleNodes(nodeMap: dict[str, Node], nodesToKeep: set[str] = set
namesToRemove.add(name)
for name in namesToRemove:
del nodeMap[name]
+
# Remove only-childs (not redundant because 'nodesToKeep' can cause single-child parents to be kept)
namesToRemove.clear()
for name, node in nodeMap.items():
@@ -277,8 +309,9 @@ def removeCollapsibleNodes(nodeMap: dict[str, Node], nodesToKeep: set[str] = set
namesToRemove.add(name)
for name in namesToRemove:
del nodeMap[name]
- #
+
return namesToRemove
+
def trimIfManyChildren(
nodeMap: dict[str, Node], rootName: str, childThreshold: int, nodesToKeep: set[str] = set()) -> None:
namesToRemove: set[str] = set()
@@ -299,14 +332,17 @@ def trimIfManyChildren(
# Recurse on children
for n in node.children:
findTrimmables(n)
+
def markForRemoval(nodeName: str) -> None:
nonlocal nodeMap, namesToRemove
namesToRemove.add(nodeName)
for child in nodeMap[nodeName].children:
markForRemoval(child)
+
findTrimmables(rootName)
for nodeName in namesToRemove:
del nodeMap[nodeName]
+
def updateTips(nodeName: str, nodeMap: dict[str, Node]) -> int:
""" Updates the 'tips' values for a node and it's descendants, returning the node's new 'tips' value """
node = nodeMap[nodeName]
@@ -314,6 +350,7 @@ def updateTips(nodeName: str, nodeMap: dict[str, Node]) -> int:
tips = max(1, tips)
node.tips = tips
return tips
+
def addTreeTables(nodeMap: dict[str, Node], dbCur: sqlite3.Cursor, suffix: str):
""" Adds a tree to the database, as tables nodes_X and edges_X, where X is the given suffix """
nodesTbl = f'nodes_{suffix}'
@@ -328,10 +365,11 @@ def addTreeTables(nodeMap: dict[str, Node], dbCur: sqlite3.Cursor, suffix: str):
pSupport = 1 if nodeMap[childName].pSupport else 0
dbCur.execute(f'INSERT INTO {edgesTbl} VALUES (?, ?, ?)', (name, childName, pSupport))
+# ========== Main block ==========
+
if __name__ == '__main__':
- import argparse
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument('--tree', choices=['picked', 'images', 'trimmed'], help='Only generate the specified tree')
args = parser.parse_args()
- #
+
genData(args.tree, DB_FILE, PICKED_NODES_FILE)