115 lines
1.8 KiB
Go
115 lines
1.8 KiB
Go
package avl
|
|
|
|
// Comparable is the key type stored in the tree.
//
// Compare orders the receiver relative to other. Insert descends left when
// the result is negative and right when it is positive. It treats zero as
// "equal", which is rejected as a duplicate key.
type Comparable interface {
	Compare(other Comparable) int
}
|
|
|
|
type Node struct {
|
|
key Comparable
|
|
height int
|
|
left, right *Node
|
|
}
|
|
|
|
func NewNode(key Comparable) *Node {
|
|
return &Node{
|
|
key: key,
|
|
height: 1,
|
|
}
|
|
}
|
|
|
|
func Insert(node *Node, key Comparable) *Node {
|
|
if node == nil {
|
|
return NewNode(key)
|
|
}
|
|
|
|
if key.Compare(node.key) < 0 {
|
|
node.left = Insert(node.left, key)
|
|
} else if key.Compare(node.key) > 0 {
|
|
node.right = Insert(node.right, key)
|
|
} else {
|
|
panic("duplicate key") // duplicate keys not allowed
|
|
}
|
|
|
|
node.height = 1 + max(height(node.left), height(node.right))
|
|
|
|
balance := getBalance(node)
|
|
|
|
// left-left case
|
|
if balance > 1 && key.Compare(node.left.key) < 0 {
|
|
return rightRotate(node)
|
|
}
|
|
|
|
// right-right case
|
|
if balance < -1 && key.Compare(node.right.key) > 0 {
|
|
return leftRotate(node)
|
|
}
|
|
|
|
// left-right case
|
|
if balance > 1 && key.Compare(node.left.key) > 0 {
|
|
node.left = leftRotate(node.left)
|
|
return rightRotate(node)
|
|
}
|
|
|
|
// right-left case
|
|
if balance < -1 && key.Compare(node.right.key) < 0 {
|
|
node.right = rightRotate(node.right)
|
|
return leftRotate(node)
|
|
}
|
|
|
|
return node
|
|
}
|
|
|
|
// max returns the larger of a and b.
func max(a, b int) int {
	if a <= b {
		return b
	}
	return a
}
|
|
|
|
func height(node *Node) int {
|
|
if node == nil {
|
|
return 0
|
|
}
|
|
|
|
return node.height
|
|
}
|
|
|
|
func getBalance(node *Node) int {
|
|
if node == nil {
|
|
return 0
|
|
}
|
|
|
|
return height(node.left) - height(node.right)
|
|
}
|
|
|
|
func rightRotate(y *Node) *Node {
|
|
var (
|
|
x *Node = y.left
|
|
T2 *Node = x.right
|
|
)
|
|
|
|
x.right = y
|
|
y.left = T2
|
|
|
|
y.height = max(height(y.left), height(y.right)) + 1
|
|
x.height = max(height(x.left), height(x.right)) + 1
|
|
|
|
return x
|
|
}
|
|
|
|
func leftRotate(x *Node) *Node {
|
|
var (
|
|
y *Node = x.right
|
|
T2 *Node = y.left
|
|
)
|
|
|
|
y.left = x
|
|
x.right = T2
|
|
|
|
x.height = max(height(x.left), height(x.right)) + 1
|
|
y.height = max(height(y.left), height(y.right)) + 1
|
|
|
|
return y
|
|
}
|