404 lines
8.1 KiB
Go
404 lines
8.1 KiB
Go
package html
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"github.com/tdewolff/parse/v2"
|
|
"github.com/tdewolff/parse/v2/css"
|
|
)
|
|
|
|
type AST struct {
|
|
Children []*Tag
|
|
Text []byte
|
|
}
|
|
|
|
func (ast *AST) String() string {
|
|
sb := strings.Builder{}
|
|
for i, child := range ast.Children {
|
|
if i != 0 {
|
|
sb.WriteString("\n")
|
|
}
|
|
sb.WriteString(child.ASTString())
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
// Attr is one attribute of a tag. The field order matters: Parse builds
// values with an unkeyed literal (key first, value second).
type Attr struct {
	// Key is the attribute name.
	Key []byte
	// Val is the attribute value; Parse stores it with surrounding quotes
	// removed.
	Val []byte
}
|
|
|
|
func (attr *Attr) String() string {
|
|
return fmt.Sprintf(`%s="%s"`, string(attr.Key), string(attr.Val))
|
|
}
|
|
|
|
type Tag struct {
|
|
Root *AST
|
|
Parent *Tag
|
|
Prev, Next *Tag
|
|
Children []*Tag
|
|
Index int
|
|
|
|
Name []byte
|
|
Attrs []Attr
|
|
textStart, textEnd int
|
|
}
|
|
|
|
func (tag *Tag) getAttr(key []byte) ([]byte, bool) {
|
|
for _, attr := range tag.Attrs {
|
|
if bytes.Equal(key, attr.Key) {
|
|
return attr.Val, true
|
|
}
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func (tag *Tag) GetAttr(key string) (string, bool) {
|
|
val, ok := tag.getAttr([]byte(key))
|
|
return string(val), ok
|
|
}
|
|
|
|
func (tag *Tag) Text() string {
|
|
return string(tag.Root.Text[tag.textStart:tag.textEnd])
|
|
}
|
|
|
|
func (tag *Tag) String() string {
|
|
sb := strings.Builder{}
|
|
sb.WriteString("<")
|
|
sb.Write(tag.Name)
|
|
for _, attr := range tag.Attrs {
|
|
sb.WriteString(" ")
|
|
sb.WriteString(attr.String())
|
|
}
|
|
sb.WriteString(">")
|
|
return sb.String()
|
|
}
|
|
|
|
func (tag *Tag) ASTString() string {
|
|
sb := strings.Builder{}
|
|
sb.WriteString(tag.String())
|
|
for _, child := range tag.Children {
|
|
sb.WriteString("\n ")
|
|
s := child.ASTString()
|
|
s = strings.ReplaceAll(s, "\n", "\n ")
|
|
sb.WriteString(s)
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
func Parse(r *parse.Input) (*AST, error) {
|
|
ast := &AST{}
|
|
root := &Tag{}
|
|
cur := root
|
|
|
|
l := NewLexer(r)
|
|
for {
|
|
tt, data := l.Next()
|
|
switch tt {
|
|
case ErrorToken:
|
|
if err := l.Err(); err != io.EOF {
|
|
return nil, err
|
|
}
|
|
ast.Children = root.Children
|
|
return ast, nil
|
|
case TextToken:
|
|
ast.Text = append(ast.Text, data...)
|
|
case StartTagToken:
|
|
child := &Tag{
|
|
Root: ast,
|
|
Parent: cur,
|
|
Index: len(cur.Children),
|
|
Name: l.Text(),
|
|
textStart: len(ast.Text),
|
|
}
|
|
if 0 < len(cur.Children) {
|
|
child.Prev = cur.Children[len(cur.Children)-1]
|
|
child.Prev.Next = child
|
|
}
|
|
cur.Children = append(cur.Children, child)
|
|
cur = child
|
|
case AttributeToken:
|
|
val := l.AttrVal()
|
|
if 0 < len(val) && (val[0] == '"' || val[0] == '\'') {
|
|
val = val[1 : len(val)-1]
|
|
}
|
|
cur.Attrs = append(cur.Attrs, Attr{l.AttrKey(), val})
|
|
case StartTagCloseToken:
|
|
if voidTags[string(cur.Name)] {
|
|
cur.textEnd = len(ast.Text)
|
|
cur = cur.Parent
|
|
}
|
|
case EndTagToken, StartTagVoidToken:
|
|
start := cur
|
|
for start != root && !bytes.Equal(l.Text(), start.Name) {
|
|
start = start.Parent
|
|
}
|
|
if start == root {
|
|
// ignore
|
|
} else {
|
|
parent := start.Parent
|
|
for cur != parent {
|
|
cur.textEnd = len(ast.Text)
|
|
cur = cur.Parent
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ast *AST) Query(s string) (*Tag, error) {
|
|
sel, err := ParseSelector(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, child := range ast.Children {
|
|
if match := child.query(sel); match != nil {
|
|
return match, nil
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (tag *Tag) query(sel selector) *Tag {
|
|
if sel.AppliesTo(tag) {
|
|
return tag
|
|
}
|
|
for _, child := range tag.Children {
|
|
if match := child.query(sel); match != nil {
|
|
return match
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ast *AST) QueryAll(s string) ([]*Tag, error) {
|
|
sel, err := ParseSelector(s)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
matches := []*Tag{}
|
|
for _, child := range ast.Children {
|
|
child.queryAll(&matches, sel)
|
|
}
|
|
return matches, nil
|
|
}
|
|
|
|
func (tag *Tag) queryAll(matches *[]*Tag, sel selector) {
|
|
if sel.AppliesTo(tag) {
|
|
*matches = append(*matches, tag)
|
|
}
|
|
for _, child := range tag.Children {
|
|
child.queryAll(matches, sel)
|
|
}
|
|
}
|
|
|
|
// attrSelector is a single condition on one attribute of a tag.
type attrSelector struct {
	// op is the comparison: 0 (attribute is present), '=' (exact value),
	// '~' (value contains the word) or '|' (value equals, or starts with the
	// value followed by '-').
	op byte
	// attr is the attribute name and val the value to compare against.
	attr, val []byte
}
|
|
|
|
func (sel attrSelector) AppliesTo(tag *Tag) bool {
|
|
val, ok := tag.getAttr(sel.attr)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
switch sel.op {
|
|
case 0:
|
|
return true
|
|
case '=':
|
|
return bytes.Equal(val, sel.val)
|
|
case '~':
|
|
if 0 < len(sel.val) {
|
|
vals := bytes.Split(val, []byte(" "))
|
|
for _, val := range vals {
|
|
if bytes.Equal(val, sel.val) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
case '|':
|
|
return bytes.Equal(val, sel.val) || bytes.HasPrefix(val, append(sel.val, '-'))
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (attr attrSelector) String() string {
|
|
sb := strings.Builder{}
|
|
sb.Write(attr.attr)
|
|
if attr.op != 0 {
|
|
sb.WriteByte(attr.op)
|
|
if attr.op != '=' {
|
|
sb.WriteByte('=')
|
|
}
|
|
sb.WriteByte('"')
|
|
sb.Write(attr.val)
|
|
sb.WriteByte('"')
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
type selectorNode struct {
|
|
typ []byte // is * for universal
|
|
attrs []attrSelector
|
|
op byte // space or >, last is NULL
|
|
}
|
|
|
|
func (sel selectorNode) AppliesTo(tag *Tag) bool {
|
|
if 0 < len(sel.typ) && !bytes.Equal(sel.typ, []byte("*")) && !bytes.Equal(sel.typ, tag.Name) {
|
|
return false
|
|
}
|
|
for _, attr := range sel.attrs {
|
|
if !attr.AppliesTo(tag) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (sel selectorNode) String() string {
|
|
sb := strings.Builder{}
|
|
sb.Write(sel.typ)
|
|
for _, attr := range sel.attrs {
|
|
if bytes.Equal(attr.attr, []byte("id")) && attr.op == '=' {
|
|
sb.WriteByte('#')
|
|
sb.Write(attr.val)
|
|
} else if bytes.Equal(attr.attr, []byte("class")) && attr.op == '~' {
|
|
sb.WriteByte('.')
|
|
sb.Write(attr.val)
|
|
} else {
|
|
sb.WriteByte('[')
|
|
sb.WriteString(attr.String())
|
|
sb.WriteByte(']')
|
|
}
|
|
}
|
|
if sel.op != 0 {
|
|
sb.WriteByte(' ')
|
|
sb.WriteByte(sel.op)
|
|
sb.WriteByte(' ')
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
type token struct {
|
|
tt css.TokenType
|
|
data []byte
|
|
}
|
|
|
|
// selector is a parsed selector: a chain of nodes in source order, each linked
// to the next one by its combinator. The last node is the one that must match
// the queried tag itself.
type selector []selectorNode
|
|
|
|
func ParseSelector(s string) (selector, error) {
|
|
ts := []token{}
|
|
l := css.NewLexer(parse.NewInputString(s))
|
|
for {
|
|
tt, data := l.Next()
|
|
if tt == css.ErrorToken {
|
|
if err := l.Err(); err != io.EOF {
|
|
return selector{}, err
|
|
}
|
|
break
|
|
}
|
|
ts = append(ts, token{
|
|
tt: tt,
|
|
data: data,
|
|
})
|
|
}
|
|
|
|
sel := selector{}
|
|
node := selectorNode{}
|
|
for i := 0; i < len(ts); i++ {
|
|
t := ts[i]
|
|
if 0 < i && (t.tt == css.WhitespaceToken || t.tt == css.DelimToken && t.data[0] == '>') {
|
|
if t.tt == css.DelimToken {
|
|
node.op = '>'
|
|
} else {
|
|
node.op = ' '
|
|
}
|
|
sel = append(sel, node)
|
|
node = selectorNode{}
|
|
} else if t.tt == css.IdentToken || t.tt == css.DelimToken && t.data[0] == '*' {
|
|
node.typ = t.data
|
|
} else if t.tt == css.DelimToken && (t.data[0] == '.' || t.data[0] == '#') && i+1 < len(ts) && ts[i+1].tt == css.IdentToken {
|
|
if t.data[0] == '#' {
|
|
node.attrs = append(node.attrs, attrSelector{op: '=', attr: []byte("id"), val: ts[i+1].data})
|
|
} else {
|
|
node.attrs = append(node.attrs, attrSelector{op: '~', attr: []byte("class"), val: ts[i+1].data})
|
|
}
|
|
i++
|
|
} else if t.tt == css.DelimToken && t.data[0] == '[' && i+2 < len(ts) && ts[i+1].tt == css.IdentToken && ts[i+2].tt == css.DelimToken {
|
|
if ts[i+2].data[0] == ']' {
|
|
node.attrs = append(node.attrs, attrSelector{op: 0, attr: ts[i+1].data})
|
|
i += 2
|
|
} else if i+4 < len(ts) && ts[i+3].tt == css.IdentToken && ts[i+4].tt == css.DelimToken && ts[i+4].data[0] == ']' {
|
|
node.attrs = append(node.attrs, attrSelector{op: ts[i+2].data[0], attr: ts[i+1].data, val: ts[i+3].data})
|
|
i += 4
|
|
}
|
|
}
|
|
}
|
|
sel = append(sel, node)
|
|
return sel, nil
|
|
}
|
|
|
|
func (sels selector) AppliesTo(tag *Tag) bool {
|
|
if len(sels) == 0 {
|
|
return true
|
|
} else if !sels[len(sels)-1].AppliesTo(tag) {
|
|
return false
|
|
}
|
|
|
|
tag = tag.Parent
|
|
isel := len(sels) - 2
|
|
for 0 <= isel && tag != nil {
|
|
switch sels[isel].op {
|
|
case ' ':
|
|
for tag != nil {
|
|
if sels[isel].AppliesTo(tag) {
|
|
break
|
|
}
|
|
tag = tag.Parent
|
|
}
|
|
case '>':
|
|
if !sels[isel].AppliesTo(tag) {
|
|
return false
|
|
}
|
|
tag = tag.Parent
|
|
default:
|
|
return false
|
|
}
|
|
isel--
|
|
}
|
|
return len(sels) != 0 && isel == -1
|
|
}
|
|
|
|
func (sels selector) String() string {
|
|
if len(sels) == 0 {
|
|
return ""
|
|
}
|
|
sb := strings.Builder{}
|
|
for _, sel := range sels {
|
|
sb.WriteString(sel.String())
|
|
}
|
|
return sb.String()[1:]
|
|
}
|
|
|
|
// voidTags is the set of tag names that Parse closes as soon as their start
// tag ends, because they have no content and no end tag. Lookups for any other
// name yield false.
var voidTags = func() map[string]bool {
	names := []string{
		"area", "base", "br", "col", "embed", "hr", "img",
		"input", "link", "meta", "source", "track", "wbr",
	}
	set := make(map[string]bool, len(names))
	for _, name := range names {
		set[name] = true
	}
	return set
}()
|