diff --git a/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java b/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java index 14c3fe012e..b90db1b714 100755 --- a/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java +++ b/saml-core/src/main/java/org/keycloak/saml/common/util/TransformerUtil.java @@ -209,8 +209,6 @@ public class TransformerUtil { if (!(outputTarget instanceof DOMResult)) throw logger.wrongTypeError("outputTarget should be a dom result"); - String rootTag = null; - StAXSource staxSource = (StAXSource) xmlSource; XMLEventReader xmlEventReader = staxSource.getXMLEventReader(); if (xmlEventReader == null) @@ -227,7 +225,6 @@ public class TransformerUtil { throw new TransformerException(ErrorCodes.WRITER_SHOULD_START_ELEMENT); StartElement rootElement = (StartElement) xmlEvent; - rootTag = StaxParserUtil.getElementName(rootElement); CustomHolder holder = new CustomHolder(doc, false); Element docRoot = handleStartElement(xmlEventReader, rootElement, holder); Node parent = doc.importNode(docRoot, true); @@ -243,6 +240,8 @@ public class TransformerUtil { while (xmlEventReader.hasNext()) { xmlEvent = StaxParserUtil.getNextEvent(xmlEventReader); int type = xmlEvent.getEventType(); + Node top = null; + switch (type) { case XMLEvent.START_ELEMENT: StartElement startElement = (StartElement) xmlEvent; @@ -250,13 +249,11 @@ public class TransformerUtil { Element docStartElement = handleStartElement(xmlEventReader, startElement, holder); Node el = doc.importNode(docStartElement, true); - Node top = null; - - if (!stack.isEmpty()) { + if (! stack.isEmpty()) { top = stack.peek(); } - if (!holder.encounteredTextNode) { + if (! holder.encounteredTextNode) { stack.push(el); } @@ -265,15 +262,15 @@ public class TransformerUtil { else top.appendChild(el); break; + case XMLEvent.END_ELEMENT: - EndElement endElement = (EndElement) xmlEvent; - String endTag = StaxParserUtil.getElementName(endElement); - if (rootTag.equals(endTag)) - return; // We are done with the dom parsing - else { - if (!stack.isEmpty()) - stack.pop(); + top = stack.pop(); + + if (! (top instanceof Element)) { + throw new TransformerException(ErrorCodes.UNKNOWN_END_ELEMENT); } + if (stack.isEmpty()) + return; // We are done with the dom parsing break; } } diff --git a/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java b/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java index 77bf1a84b5..14438c4eeb 100644 --- a/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java +++ b/saml-core/src/test/java/org/keycloak/saml/common/util/StaxParserUtilTest.java @@ -18,18 +18,23 @@ package org.keycloak.saml.common.util; import org.keycloak.saml.common.exceptions.ParsingException; import java.nio.charset.Charset; +import java.util.NoSuchElementException; import javax.xml.stream.XMLEventReader; import javax.xml.stream.XMLStreamException; import javax.xml.stream.events.Characters; +import javax.xml.stream.events.EndDocument; import javax.xml.stream.events.EndElement; import javax.xml.stream.events.StartDocument; import javax.xml.stream.events.StartElement; import javax.xml.stream.events.XMLEvent; import org.apache.commons.io.IOUtils; import org.hamcrest.Matcher; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.w3c.dom.Element; +import org.w3c.dom.Text; import static org.junit.Assert.assertThat; import static org.hamcrest.CoreMatchers.*; @@ -175,4 +180,38 @@ public class StaxParserUtilTest { reader.nextEvent(); } + @Test + public void testGetDOMElementSameElements() throws XMLStreamException, ParsingException { + String xml = "b"; + XMLEventReader reader = StaxParserUtil.getXMLEventReader(IOUtils.toInputStream(xml, Charset.defaultCharset())); + + assertThat(reader.nextEvent(), instanceOf(StartDocument.class)); + + assertStartTag(reader.nextEvent(), "root"); + + Element element = StaxParserUtil.getDOMElement(reader); + + assertThat(element.getNodeName(), is("test")); + assertThat(element.getChildNodes().getLength(), is(1)); + + assertThat(element.getChildNodes().item(0), instanceOf(Element.class)); + Element e = (Element) element.getChildNodes().item(0); + assertThat(e.getNodeName(), is("test")); + + assertThat(e.getChildNodes().getLength(), is(1)); + assertThat(e.getChildNodes().item(0), instanceOf(Element.class)); + Element e1 = (Element) e.getChildNodes().item(0); + assertThat(e1.getNodeName(), is("a")); + + assertThat(e1.getChildNodes().getLength(), is(1)); + assertThat(e1.getChildNodes().item(0), instanceOf(Text.class)); + assertThat(((Text) e1.getChildNodes().item(0)).getWholeText(), is("b")); + + assertEndTag(reader.nextEvent(), "root"); + assertThat(reader.nextEvent(), instanceOf(EndDocument.class)); + + expectedException.expect(NoSuchElementException.class); + Assert.fail(String.valueOf(reader.nextEvent())); + } + }