avl/avl.go
2023-08-07 13:03:35 +10:00

115 lines
1.8 KiB
Go

package avl
// Comparable is implemented by every key stored in the tree.
//
// Compare orders the receiver relative to other. It returns a negative value
// when the receiver sorts before other and a positive value when it sorts
// after. It returns zero when the two are equal; Insert treats a zero result
// as a duplicate key.
type Comparable interface {
	Compare(other Comparable) int
}
// Node is one node of an AVL tree; a nil *Node is an empty (sub)tree.
//
// height is the node count on the longest downward path from this node, so a
// leaf has height 1 and an empty subtree counts as 0. Keys that Compare less
// than key live under left; keys that Compare greater live under right.
type Node struct {
	key    Comparable
	height int
	left   *Node // subtree of keys ordering before key
	right  *Node // subtree of keys ordering after key
}
// NewNode allocates a childless node holding key. Its height starts at 1,
// one more than the height of an empty subtree.
func NewNode(key Comparable) *Node {
	leaf := new(Node)
	leaf.key = key
	leaf.height = 1
	return leaf
}
// Insert adds key to the AVL subtree rooted at node and returns the root of
// the resulting subtree. Rebalancing can rotate a different node into the
// root position, so callers must always use the returned pointer
// (root = Insert(root, k)). A nil node denotes an empty tree.
//
// Insert panics with "duplicate key" if a key that Compares equal to key is
// already present.
func Insert(node *Node, key Comparable) *Node {
	if node == nil {
		return NewNode(key)
	}

	// Compare exactly once per level; Compare may be arbitrarily expensive.
	switch c := key.Compare(node.key); {
	case c < 0:
		node.left = Insert(node.left, key)
	case c > 0:
		node.right = Insert(node.right, key)
	default:
		panic("duplicate key") // duplicate keys not allowed
	}

	node.height = 1 + max(height(node.left), height(node.right))

	// After a single insertion the heavy child of an unbalanced node always
	// leans one way or the other (balance +1 or -1). Its sign tells us whether
	// the new key went to the outer side (single rotation) or the inner side
	// (double rotation). No further key comparisons are needed.
	switch balance := getBalance(node); {
	case balance > 1: // left-heavy
		if getBalance(node.left) < 0 {
			// left-right case: rotate the child so it becomes left-left.
			node.left = leftRotate(node.left)
		}
		// left-left case
		return rightRotate(node)
	case balance < -1: // right-heavy
		if getBalance(node.right) > 0 {
			// right-left case: rotate the child so it becomes right-right.
			node.right = rightRotate(node.right)
		}
		// right-right case
		return leftRotate(node)
	}
	return node
}
// max returns the larger of a and b. When they are equal the shared value is
// returned.
func max(a, b int) int {
	if a <= b {
		return b
	}
	return a
}
// height reports the cached height of node, treating a nil (empty) subtree
// as height 0.
func height(node *Node) int {
	h := 0
	if node != nil {
		h = node.height
	}
	return h
}
// getBalance returns the balance factor of node: left height minus right
// height. Positive means left-heavy and negative means right-heavy. A nil
// node is considered balanced (0).
func getBalance(node *Node) int {
	if node != nil {
		return height(node.left) - height(node.right)
	}
	return 0
}
// rightRotate lifts y's left child into y's place and returns it as the new
// subtree root. y becomes that child's right child. The child's former right
// subtree is re-attached as y's left subtree. Both heights are recomputed,
// y first because the new root's height depends on it. y.left must be
// non-nil.
func rightRotate(y *Node) *Node {
	pivot := y.left
	// The right-hand side is evaluated before any field is written, so
	// pivot.right is read before being overwritten.
	y.left, pivot.right = pivot.right, y
	y.height = 1 + max(height(y.left), height(y.right))
	pivot.height = 1 + max(height(pivot.left), height(pivot.right))
	return pivot
}
// leftRotate lifts x's right child into x's place and returns it as the new
// subtree root. x becomes that child's left child. The child's former left
// subtree is re-attached as x's right subtree. Both heights are recomputed,
// x first because the new root's height depends on it. x.right must be
// non-nil.
func leftRotate(x *Node) *Node {
	pivot := x.right
	// The right-hand side is evaluated before any field is written, so
	// pivot.left is read before being overwritten.
	x.right, pivot.left = pivot.left, x
	x.height = 1 + max(height(x.left), height(x.right))
	pivot.height = 1 + max(height(pivot.left), height(pivot.right))
	return pivot
}