# compound_strategy.py (2.1 KB)
  1. from spacy.tokens import Token
  2. from word_processor.types import DEP_TYPES
  3. def compound_strategy(doc) -> [[Token]]:
  4. """
  5. Should return an arrays of variable names based on compound strategy
  6. Uses adverbial strategy and also adds compounds to nouns
  7. e.g.
  8. Reads black phone numbers
  9. Will treat PhoneNumbers as a single entity.
  10. :param doc: spacy document
  11. :return Array of strings
  12. """
  13. suggestions = []
  14. for token in doc:
  15. if token.dep_ == DEP_TYPES['ROOT']:
  16. suggestions = compound_dfs(token)
  17. break
  18. return suggestions
  19. INVALID_DEP = ['aux', 'prep']
  20. INVALID_POS = ['DET', 'AUX', 'ADP']
  21. FLIP_DEP = ['compound', 'amod']
  22. def is_valid(node, existing=None):
  23. return (
  24. not (node.pos_ in INVALID_POS)
  25. and not (node.dep_ in INVALID_DEP)
  26. and not (node in existing if existing else False)
  27. )
  28. def should_flip(node):
  29. return node.dep_ in FLIP_DEP
  30. # todo - this is dfs
  31. def compound_dfs(node, result=None, output=None):
  32. if output is None:
  33. output = []
  34. if result is None:
  35. result = []
  36. has_parent = node.head != node
  37. has_children = (node.n_lefts + node.n_rights) > 0
  38. has_both_directions = node.n_lefts > 0 and node.n_rights > 0
  39. if should_flip(node) and has_parent and is_valid(node, result):
  40. compound_dfs(node.head, [*result[:-1], node], output)
  41. # dunno if i should return here
  42. return output
  43. if node.pos_ == 'VERB' and has_both_directions:
  44. for lefty in node.lefts:
  45. for righty in node.rights:
  46. if is_valid(lefty, result):
  47. compound_dfs(righty, [*result, lefty], output)
  48. # if is_valid(righty):
  49. # compound_dfs(lefty, [*result, righty], output)
  50. elif has_children:
  51. for u in node.children:
  52. valid_results = [*result, node] if is_valid(node, result) else result
  53. compound_dfs(u, valid_results, output)
  54. else:
  55. valid_results = [*result, node] if is_valid(node, result) else result
  56. output.append(valid_results)
  57. return output