main
  1#!/usr/bin/env python3
  2"""
  3Utilities for editing OOXML documents.
  4
  5This module provides XMLEditor, a tool for manipulating XML files with support for
  6line-number-based node finding and DOM manipulation. Each element is automatically
  7annotated with its original line and column position during parsing.
  8
  9Example usage:
 10    editor = XMLEditor("document.xml")
 11
 12    # Find node by line number or range
 13    elem = editor.get_node(tag="w:r", line_number=519)
 14    elem = editor.get_node(tag="w:p", line_number=range(100, 200))
 15
 16    # Find node by text content
 17    elem = editor.get_node(tag="w:p", contains="specific text")
 18
 19    # Find node by attributes
 20    elem = editor.get_node(tag="w:r", attrs={"w:id": "target"})
 21
 22    # Combine filters
 23    elem = editor.get_node(tag="w:p", line_number=range(1, 50), contains="text")
 24
 25    # Replace, insert, or manipulate
 26    new_elem = editor.replace_node(elem, "<w:r><w:t>new text</w:t></w:r>")
 27    editor.insert_after(new_elem, "<w:r><w:t>more</w:t></w:r>")
 28
 29    # Save changes
 30    editor.save()
 31"""
 32
 33import html
 34from pathlib import Path
 35from typing import Optional, Union
 36
 37import defusedxml.minidom
 38import defusedxml.sax
 39
 40
 41class XMLEditor:
 42    """
 43    Editor for manipulating OOXML XML files with line-number-based node finding.
 44
 45    This class parses XML files and tracks the original line and column position
 46    of each element. This enables finding nodes by their line number in the original
 47    file, which is useful when working with Read tool output.
 48
 49    Attributes:
 50        xml_path: Path to the XML file being edited
 51        encoding: Detected encoding of the XML file ('ascii' or 'utf-8')
 52        dom: Parsed DOM tree with parse_position attributes on elements
 53    """
 54
 55    def __init__(self, xml_path):
 56        """
 57        Initialize with path to XML file and parse with line number tracking.
 58
 59        Args:
 60            xml_path: Path to XML file to edit (str or Path)
 61
 62        Raises:
 63            ValueError: If the XML file does not exist
 64        """
 65        self.xml_path = Path(xml_path)
 66        if not self.xml_path.exists():
 67            raise ValueError(f"XML file not found: {xml_path}")
 68
 69        with open(self.xml_path, "rb") as f:
 70            header = f.read(200).decode("utf-8", errors="ignore")
 71        self.encoding = "ascii" if 'encoding="ascii"' in header else "utf-8"
 72
 73        parser = _create_line_tracking_parser()
 74        self.dom = defusedxml.minidom.parse(str(self.xml_path), parser)
 75
 76    def get_node(
 77        self,
 78        tag: str,
 79        attrs: Optional[dict[str, str]] = None,
 80        line_number: Optional[Union[int, range]] = None,
 81        contains: Optional[str] = None,
 82    ):
 83        """
 84        Get a DOM element by tag and identifier.
 85
 86        Finds an element by either its line number in the original file or by
 87        matching attribute values. Exactly one match must be found.
 88
 89        Args:
 90            tag: The XML tag name (e.g., "w:del", "w:ins", "w:r")
 91            attrs: Dictionary of attribute name-value pairs to match (e.g., {"w:id": "1"})
 92            line_number: Line number (int) or line range (range) in original XML file (1-indexed)
 93            contains: Text string that must appear in any text node within the element.
 94                      Supports both entity notation (&#8220;) and Unicode characters (\u201c).
 95
 96        Returns:
 97            defusedxml.minidom.Element: The matching DOM element
 98
 99        Raises:
100            ValueError: If node not found or multiple matches found
101
102        Example:
103            elem = editor.get_node(tag="w:r", line_number=519)
104            elem = editor.get_node(tag="w:r", line_number=range(100, 200))
105            elem = editor.get_node(tag="w:del", attrs={"w:id": "1"})
106            elem = editor.get_node(tag="w:p", attrs={"w14:paraId": "12345678"})
107            elem = editor.get_node(tag="w:commentRangeStart", attrs={"w:id": "0"})
108            elem = editor.get_node(tag="w:p", contains="specific text")
109            elem = editor.get_node(tag="w:t", contains="&#8220;Agreement")  # Entity notation
110            elem = editor.get_node(tag="w:t", contains="\u201cAgreement")   # Unicode character
111        """
112        matches = []
113        for elem in self.dom.getElementsByTagName(tag):
114            # Check line_number filter
115            if line_number is not None:
116                parse_pos = getattr(elem, "parse_position", (None,))
117                elem_line = parse_pos[0]
118
119                # Handle both single line number and range
120                if isinstance(line_number, range):
121                    if elem_line not in line_number:
122                        continue
123                else:
124                    if elem_line != line_number:
125                        continue
126
127            # Check attrs filter
128            if attrs is not None:
129                if not all(
130                    elem.getAttribute(attr_name) == attr_value
131                    for attr_name, attr_value in attrs.items()
132                ):
133                    continue
134
135            # Check contains filter
136            if contains is not None:
137                elem_text = self._get_element_text(elem)
138                # Normalize the search string: convert HTML entities to Unicode characters
139                # This allows searching for both "&#8220;Rowan" and ""Rowan"
140                normalized_contains = html.unescape(contains)
141                if normalized_contains not in elem_text:
142                    continue
143
144            # If all applicable filters passed, this is a match
145            matches.append(elem)
146
147        if not matches:
148            # Build descriptive error message
149            filters = []
150            if line_number is not None:
151                line_str = (
152                    f"lines {line_number.start}-{line_number.stop - 1}"
153                    if isinstance(line_number, range)
154                    else f"line {line_number}"
155                )
156                filters.append(f"at {line_str}")
157            if attrs is not None:
158                filters.append(f"with attributes {attrs}")
159            if contains is not None:
160                filters.append(f"containing '{contains}'")
161
162            filter_desc = " ".join(filters) if filters else ""
163            base_msg = f"Node not found: <{tag}> {filter_desc}".strip()
164
165            # Add helpful hint based on filters used
166            if contains:
167                hint = "Text may be split across elements or use different wording."
168            elif line_number:
169                hint = "Line numbers may have changed if document was modified."
170            elif attrs:
171                hint = "Verify attribute values are correct."
172            else:
173                hint = "Try adding filters (attrs, line_number, or contains)."
174
175            raise ValueError(f"{base_msg}. {hint}")
176        if len(matches) > 1:
177            raise ValueError(
178                f"Multiple nodes found: <{tag}>. "
179                f"Add more filters (attrs, line_number, or contains) to narrow the search."
180            )
181        return matches[0]
182
183    def _get_element_text(self, elem):
184        """
185        Recursively extract all text content from an element.
186
187        Skips text nodes that contain only whitespace (spaces, tabs, newlines),
188        which typically represent XML formatting rather than document content.
189
190        Args:
191            elem: defusedxml.minidom.Element to extract text from
192
193        Returns:
194            str: Concatenated text from all non-whitespace text nodes within the element
195        """
196        text_parts = []
197        for node in elem.childNodes:
198            if node.nodeType == node.TEXT_NODE:
199                # Skip whitespace-only text nodes (XML formatting)
200                if node.data.strip():
201                    text_parts.append(node.data)
202            elif node.nodeType == node.ELEMENT_NODE:
203                text_parts.append(self._get_element_text(node))
204        return "".join(text_parts)
205
206    def replace_node(self, elem, new_content):
207        """
208        Replace a DOM element with new XML content.
209
210        Args:
211            elem: defusedxml.minidom.Element to replace
212            new_content: String containing XML to replace the node with
213
214        Returns:
215            List[defusedxml.minidom.Node]: All inserted nodes
216
217        Example:
218            new_nodes = editor.replace_node(old_elem, "<w:r><w:t>text</w:t></w:r>")
219        """
220        parent = elem.parentNode
221        nodes = self._parse_fragment(new_content)
222        for node in nodes:
223            parent.insertBefore(node, elem)
224        parent.removeChild(elem)
225        return nodes
226
227    def insert_after(self, elem, xml_content):
228        """
229        Insert XML content after a DOM element.
230
231        Args:
232            elem: defusedxml.minidom.Element to insert after
233            xml_content: String containing XML to insert
234
235        Returns:
236            List[defusedxml.minidom.Node]: All inserted nodes
237
238        Example:
239            new_nodes = editor.insert_after(elem, "<w:r><w:t>text</w:t></w:r>")
240        """
241        parent = elem.parentNode
242        next_sibling = elem.nextSibling
243        nodes = self._parse_fragment(xml_content)
244        for node in nodes:
245            if next_sibling:
246                parent.insertBefore(node, next_sibling)
247            else:
248                parent.appendChild(node)
249        return nodes
250
251    def insert_before(self, elem, xml_content):
252        """
253        Insert XML content before a DOM element.
254
255        Args:
256            elem: defusedxml.minidom.Element to insert before
257            xml_content: String containing XML to insert
258
259        Returns:
260            List[defusedxml.minidom.Node]: All inserted nodes
261
262        Example:
263            new_nodes = editor.insert_before(elem, "<w:r><w:t>text</w:t></w:r>")
264        """
265        parent = elem.parentNode
266        nodes = self._parse_fragment(xml_content)
267        for node in nodes:
268            parent.insertBefore(node, elem)
269        return nodes
270
271    def append_to(self, elem, xml_content):
272        """
273        Append XML content as a child of a DOM element.
274
275        Args:
276            elem: defusedxml.minidom.Element to append to
277            xml_content: String containing XML to append
278
279        Returns:
280            List[defusedxml.minidom.Node]: All inserted nodes
281
282        Example:
283            new_nodes = editor.append_to(elem, "<w:r><w:t>text</w:t></w:r>")
284        """
285        nodes = self._parse_fragment(xml_content)
286        for node in nodes:
287            elem.appendChild(node)
288        return nodes
289
290    def get_next_rid(self):
291        """Get the next available rId for relationships files."""
292        max_id = 0
293        for rel_elem in self.dom.getElementsByTagName("Relationship"):
294            rel_id = rel_elem.getAttribute("Id")
295            if rel_id.startswith("rId"):
296                try:
297                    max_id = max(max_id, int(rel_id[3:]))
298                except ValueError:
299                    pass
300        return f"rId{max_id + 1}"
301
302    def save(self):
303        """
304        Save the edited XML back to the file.
305
306        Serializes the DOM tree and writes it back to the original file path,
307        preserving the original encoding (ascii or utf-8).
308        """
309        content = self.dom.toxml(encoding=self.encoding)
310        self.xml_path.write_bytes(content)
311
312    def _parse_fragment(self, xml_content):
313        """
314        Parse XML fragment and return list of imported nodes.
315
316        Args:
317            xml_content: String containing XML fragment
318
319        Returns:
320            List of defusedxml.minidom.Node objects imported into this document
321
322        Raises:
323            AssertionError: If fragment contains no element nodes
324        """
325        # Extract namespace declarations from the root document element
326        root_elem = self.dom.documentElement
327        namespaces = []
328        if root_elem and root_elem.attributes:
329            for i in range(root_elem.attributes.length):
330                attr = root_elem.attributes.item(i)
331                if attr.name.startswith("xmlns"):  # type: ignore
332                    namespaces.append(f'{attr.name}="{attr.value}"')  # type: ignore
333
334        ns_decl = " ".join(namespaces)
335        wrapper = f"<root {ns_decl}>{xml_content}</root>"
336        fragment_doc = defusedxml.minidom.parseString(wrapper)
337        nodes = [
338            self.dom.importNode(child, deep=True)
339            for child in fragment_doc.documentElement.childNodes  # type: ignore
340        ]
341        elements = [n for n in nodes if n.nodeType == n.ELEMENT_NODE]
342        assert elements, "Fragment must contain at least one element"
343        return nodes
344
345
def _create_line_tracking_parser():
    """
    Build a SAX parser that stamps each element with where it started.

    The parser's setContentHandler is swapped for a wrapper. Whenever a
    content handler is installed, the wrapper patches that handler's
    startElementNS. After the original start-tag handling has run, it tags the
    element at the top of the handler's elementStack with a ``parse_position``
    attribute. That element is the one just opened when the handler is
    xml.dom.pulldom's PullDOM. The attribute holds an (line, column) tuple read
    from the underlying expat parser at that moment.

    Returns:
        defusedxml.sax.xmlreader.XMLReader: Configured SAX parser
    """
    sax_parser = defusedxml.sax.make_parser()
    install_handler = sax_parser.setContentHandler

    def tracking_set_content_handler(dom_handler):
        handle_start = dom_handler.startElementNS

        def start_and_record_position(name, tagName, attrs):
            handle_start(name, tagName, attrs)
            # Private attribute of the SAX reader: the live expat parser,
            # whose Current* fields describe the event being handled.
            expat = sax_parser._parser  # type: ignore
            dom_handler.elementStack[-1].parse_position = (
                expat.CurrentLineNumber,
                expat.CurrentColumnNumber,
            )

        dom_handler.startElementNS = start_and_record_position
        install_handler(dom_handler)

    sax_parser.setContentHandler = tracking_set_content_handler  # type: ignore
    return sax_parser
374    return parser