simple-privacy-tool/privacy/privacy.go

418 lines
8.4 KiB
Go

package privacy
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"golang.org/x/crypto/chacha20poly1305"
"io"
)
// CipherMethodType identifies the AEAD cipher used for a stream. It is a single
// byte so that it can be stored in the first byte of the stream's salt header.
type CipherMethodType byte
// KeyGen derives cipher key material from a password and a salt.
// Implementations must also be JSON-marshalable.
type KeyGen interface {
json.Marshaler
// GenerateKey returns the key bytes derived from password and salt. The
// result is handed directly to the cipher constructor selected in
// Privacy.GenerateKey, so its length must be one that cipher accepts.
GenerateKey(password, salt []byte) []byte
}
const (
// segmentSizeBytesLen is the size, in bytes, of the little-endian length
// prefix written before every encrypted segment.
segmentSizeBytesLen int = 4
// Uninitialised marks a Privacy whose cipher method has not been chosen yet.
Uninitialised CipherMethodType = 0x00
// XChaCha20Simple selects the AEAD built by chacha20poly1305.NewX.
XChaCha20Simple CipherMethodType = 0x01
// AES256GCMSimple selects AES in GCM mode (aes.NewCipher + cipher.NewGCM).
AES256GCMSimple CipherMethodType = 0x02
// DefaultCipherMethod is the method used by NewPrivacyWriterCloserDefault.
DefaultCipherMethod = XChaCha20Simple
)
var (
// ErrInvalidSaltLen is returned by SetSalt when the supplied salt is not
// exactly 16 bytes long.
ErrInvalidSaltLen = errors.New("invalid salt length")
// ErrUninitialisedSalt is returned by GenerateKey when no 16-byte salt is set.
ErrUninitialisedSalt = errors.New("uninitialised salt")
// ErrUninitialisedMethod is returned by NewSalt and GenerateKey while the
// cipher method is still Uninitialised.
ErrUninitialisedMethod = errors.New("cipher method type uninitialised")
//ErrCannotReadMagicBytes = errors.New("cannot read magic bytes") //no usage for now
// ErrInvalidReadFlow is returned by Reader.Read when the cipher method has
// not been established yet (ReadMagic has not been called).
ErrInvalidReadFlow = errors.New("func ReadMagic should be called before calling Read")
// ErrInvalidKeyState is returned by Read and Write when no AEAD has been
// set up yet (GenerateKey has not been called successfully).
ErrInvalidKeyState = errors.New("func GenerateKey should be called first")
// ErrInvalidSegmentLength is returned by Reader.Read when a segment header
// declares a length larger than the configured segment size.
ErrInvalidSegmentLength = errors.New("segment length is too long")
// segmentLenBytes is a package-level scratch buffer for an encoded segment
// length. It is mutable state shared by the whole package, so any code that
// uses it is not safe to run concurrently with other users of it.
segmentLenBytes = make([]byte, segmentSizeBytesLen)
)
// InvalidCipherMethod is the error value used when a cipher method is not one
// of the supported types. Its underlying bytes are whatever the creator chose
// to attach (for example the header bytes that failed validation).
type InvalidCipherMethod []byte

// Error implements the error interface. The message is fixed and does not
// include the attached bytes.
func (InvalidCipherMethod) Error() string {
	const msg = "invalid cipher method type"
	return msg
}
// Reader decrypts a segmented, AEAD-protected stream read from an underlying
// io.Reader. It embeds *Privacy for the salt, cipher and segment settings.
type Reader struct {
*Privacy
// reader is the source of the encrypted stream.
reader io.Reader
// buf is scratch space sized for one segment: nonce, ciphertext and tag.
buf []byte
// bufSlice holds decrypted bytes that have not been handed to the caller yet.
bufSlice []byte
// isEOF is set once the underlying reader reported io.EOF while a segment
// header was being read; later Read calls then return io.EOF immediately.
isEOF bool
}
// WriteCloser encrypts data into length-prefixed AEAD segments and writes them
// to an underlying io.WriteCloser. It embeds *Privacy for the salt, cipher and
// segment settings.
type WriteCloser struct {
*Privacy
// writeCloser is the destination of the encrypted stream.
writeCloser io.WriteCloser
// buf is scratch space sized for one segment: nonce, plaintext and tag.
buf []byte
// bufSlice is the plaintext buffered so far for the segment being filled.
bufSlice []byte
// magicWritten records whether the salt header has already been written.
magicWritten bool
}
// Privacy holds the cipher state shared by Reader and WriteCloser.
type Privacy struct {
// salt is the 16-byte salt passed to the key generator; NewSalt stores the
// cipher method in its first byte.
salt []byte
// segmentSize is the maximum number of plaintext bytes per segment.
segmentSize uint32
// cmType is the selected cipher method, or Uninitialised.
cmType CipherMethodType
// aead is the cipher built by GenerateKey; it is nil until then.
aead cipher.AEAD
// keygen derives the key from the passphrase and salt.
keygen KeyGen
}
// newPrivacy builds a Privacy that uses k for key derivation, starts with an
// Uninitialised cipher method and a 64 MiB segment size.
func newPrivacy(k KeyGen) *Privacy {
	p := new(Privacy)
	p.keygen = k
	p.cmType = Uninitialised
	p.segmentSize = 64 << 20 // 64 MiB
	return p
}
// NewPrivacyReader wraps reader in a decrypting Reader whose key generator is
// the one returned by NewArgon2.
func NewPrivacyReader(reader io.Reader) *Reader {
	defaultKeyGen := NewArgon2()
	return NewPrivacyReaderWithKeyGen(reader, defaultKeyGen)
}
// NewPrivacyReaderWithKeyGen wraps reader in a decrypting Reader that derives
// its key with keygen. The cipher method stays Uninitialised until ReadMagic
// is called.
func NewPrivacyReaderWithKeyGen(reader io.Reader, keygen KeyGen) *Reader {
	r := new(Reader)
	r.Privacy = newPrivacy(keygen)
	r.reader = reader
	r.isEOF = false
	return r
}
// NewPrivacyWriterCloserDefault wraps wc in an encrypting WriteCloser that uses
// DefaultCipherMethod.
func NewPrivacyWriterCloserDefault(wc io.WriteCloser) *WriteCloser {
	method := DefaultCipherMethod
	return NewPrivacyWriteCloser(wc, method)
}
// NewPrivacyWriteCloser wraps wc in an encrypting WriteCloser that uses the
// cipher method cmType and the key generator returned by NewArgon2.
func NewPrivacyWriteCloser(wc io.WriteCloser, cmType CipherMethodType) *WriteCloser {
	defaultKeyGen := NewArgon2()
	return NewPrivacyWriteCloserWithKeyGen(wc, cmType, defaultKeyGen)
}
// NewPrivacyWriteCloserWithKeyGen wraps wc in an encrypting WriteCloser that
// uses the cipher method cmType and derives its key with keygen. The salt
// header has not been written yet on the returned value.
func NewPrivacyWriteCloserWithKeyGen(wc io.WriteCloser, cmType CipherMethodType, keygen KeyGen) *WriteCloser {
	base := newPrivacy(keygen)
	base.cmType = cmType

	w := new(WriteCloser)
	w.Privacy = base
	w.writeCloser = wc
	w.magicWritten = false
	return w
}
// SetSalt copies salt into the Privacy's own 16-byte salt buffer.
// It returns ErrInvalidSaltLen when salt is not exactly 16 bytes long.
// The caller's slice is not retained.
func (p *Privacy) SetSalt(salt []byte) error {
	const saltLen = 16
	switch {
	case len(salt) != saltLen:
		return ErrInvalidSaltLen
	case len(p.salt) != saltLen:
		// No usable buffer yet: allocate one before copying.
		p.salt = make([]byte, saltLen)
	}
	copy(p.salt, salt)
	return nil
}
// GetSegmentSize reports the currently configured segment size, i.e. the
// maximum number of plaintext bytes carried by one segment.
func (p *Privacy) GetSegmentSize() uint32 {
	current := p.segmentSize
	return current
}
// SetSegmentSize sets the maximum number of plaintext bytes per segment.
//
// A Reader rejects any segment whose declared length exceeds its own segment
// size, so the size used for reading must be at least the size used for
// writing. The value is stored as-is; zero is not rejected here, so callers
// should pass a positive size.
func (p *Privacy) SetSegmentSize(size uint32) {
	p.segmentSize = size
}
// NewSalt generates a fresh 16-byte salt: the first byte is the cipher method
// and the remaining 15 bytes come from crypto/rand.
//
// It returns ErrUninitialisedMethod when no cipher method has been selected,
// or the error reported by the random source. The stored salt is replaced only
// on success, so a failed call leaves the previous state untouched.
func (p *Privacy) NewSalt() error {
	if p.cmType == Uninitialised {
		return ErrUninitialisedMethod
	}
	fresh := make([]byte, 16)
	fresh[0] = byte(p.cmType)
	if _, err := rand.Read(fresh[1:]); err != nil {
		return err
	}
	p.salt = fresh
	return nil
}
// GenerateKey derives a key from passphrase and the stored salt using the
// configured KeyGen, then builds the AEAD for the selected cipher method.
//
// It returns ErrUninitialisedMethod when no cipher method is selected,
// ErrUninitialisedSalt when no 16-byte salt is set, InvalidCipherMethod for an
// unknown method, or the error from the cipher constructor.
func (p *Privacy) GenerateKey(passphrase string) error {
	switch {
	case p.cmType == Uninitialised:
		return ErrUninitialisedMethod
	case len(p.salt) != 16:
		return ErrUninitialisedSalt
	}

	derived := p.keygen.GenerateKey([]byte(passphrase), p.salt)

	var err error
	switch p.cmType {
	case AES256GCMSimple:
		var blk cipher.Block
		if blk, err = aes.NewCipher(derived); err != nil {
			return err
		}
		p.aead, err = cipher.NewGCM(blk)
	case XChaCha20Simple:
		p.aead, err = chacha20poly1305.NewX(derived)
	default:
		err = InvalidCipherMethod([]byte{})
	}
	return err
}
// Write buffers b into the current plaintext segment and encrypts and emits a
// segment each time the buffer is full when more data arrives. Data that does
// not fill a segment stays buffered until a later Write or Close.
//
// The first call also writes the salt header to the underlying writer.
//
// n is the number of bytes of b that were consumed (buffered or written); it
// never counts header or ciphertext bytes. Write returns ErrInvalidKeyState if
// GenerateKey has not been called, an error if the segment size is zero, or
// the error from the underlying writer.
func (wc *WriteCloser) Write(b []byte) (n int, err error) {
	if wc.aead == nil {
		return 0, ErrInvalidKeyState
	}
	if wc.segmentSize == 0 {
		// A zero-sized segment is always "full": without this guard the loop
		// below would emit empty segments forever without consuming b.
		return 0, errors.New("segment size must be greater than zero")
	}

	nonceSize := wc.aead.NonceSize()
	segSize := int(wc.segmentSize)

	// (Re)allocate the scratch buffer: nonce | plaintext | tag.
	if want := segSize + nonceSize + wc.aead.Overhead(); cap(wc.buf) != want {
		wc.buf = make([]byte, want)
		wc.bufSlice = wc.buf[nonceSize:nonceSize]
	}

	if !wc.magicWritten {
		if _, err = wc.writeUp(wc.salt); err != nil {
			return 0, err
		}
		wc.magicWritten = true
	}

	for n < len(b) {
		pending := len(wc.bufSlice)
		if pending == segSize {
			// Buffer is full: flush it before taking more input.
			if _, err = wc.writeSegment(); err != nil {
				return n, err
			}
			continue
		}

		chunk := b[n:]
		if room := segSize - pending; len(chunk) > room {
			chunk = chunk[:room]
		}
		// Grow the plaintext window inside buf and append the chunk to it.
		wc.bufSlice = wc.buf[nonceSize : nonceSize+pending+len(chunk)]
		n += copy(wc.bufSlice[pending:], chunk)
	}
	return n, nil
}
// writeSegment encrypts the buffered plaintext and writes one segment to the
// underlying writer, laid out as:
//
//	4-byte little-endian plaintext length | nonce | ciphertext + tag
//
// The length prefix is also passed to the AEAD as additional data. On success
// the plaintext buffer is reset and n is the number of plaintext bytes that
// were sealed; on failure n is 0.
//
// The length header lives in a local buffer so that independent WriteClosers
// never share mutable state.
func (wc *WriteCloser) writeSegment() (n int, err error) {
	var header [segmentSizeBytesLen]byte

	nonceSize := wc.aead.NonceSize()
	written := len(wc.bufSlice)
	binary.LittleEndian.PutUint32(header[:], uint32(written))

	// Prepare the whole segment before touching the output, so that a failing
	// random source cannot leave a dangling length header in the stream.
	nonce := wc.buf[:nonceSize]
	if _, err = rand.Read(nonce); err != nil {
		return 0, err
	}
	plaintext := wc.buf[nonceSize : nonceSize+written]
	// Seal in place: buf always has room for the tag after the plaintext.
	sealed := wc.aead.Seal(plaintext[:0], nonce, plaintext, header[:])

	if _, err = wc.writeUp(header[:]); err != nil {
		return 0, err
	}
	if _, err = wc.writeUp(wc.buf[:nonceSize+len(sealed)]); err != nil {
		return 0, err
	}

	wc.bufSlice = wc.buf[nonceSize:nonceSize]
	return written, nil
}
// writeUp writes all of b to the underlying writer, retrying after short
// writes. It returns the number of bytes written and the first error
// encountered. If the underlying writer accepts no bytes and reports no error,
// writeUp returns io.ErrShortWrite instead of retrying forever.
func (wc *WriteCloser) writeUp(b []byte) (n int, err error) {
	for {
		m, werr := wc.writeCloser.Write(b[n:])
		if m > 0 {
			n += m
		}
		if werr != nil {
			return n, werr
		}
		if n >= len(b) {
			return n, nil
		}
		if m <= 0 {
			// No progress and no error: the writer violates the io.Writer
			// contract, so bail out rather than spin.
			return n, io.ErrShortWrite
		}
	}
}
// Close flushes any buffered plaintext as a final segment and then closes the
// underlying writer.
//
// The underlying writer is closed even when the flush fails, so that it is not
// leaked; in that case the flush error is the one returned.
func (wc *WriteCloser) Close() (err error) {
	if len(wc.bufSlice) > 0 {
		_, err = wc.writeSegment()
	}
	if cerr := wc.writeCloser.Close(); err == nil {
		err = cerr
	}
	return err
}
// ReadMagic reads the 16-byte header from the underlying reader, takes the
// cipher method from its first byte and stores all 16 bytes as the salt.
//
// It does nothing when a cipher method is already set. It returns the read
// error, or InvalidCipherMethod (carrying the header bytes) when the first
// byte is not a supported method.
func (r *Reader) ReadMagic() (err error) {
	if r.cmType != Uninitialised {
		return nil
	}

	header := make([]byte, 16)
	if _, err = r.readUp(header); err != nil {
		return err
	}

	method := CipherMethodType(header[0])
	if method != XChaCha20Simple && method != AES256GCMSimple {
		return InvalidCipherMethod(header)
	}
	r.cmType = method

	return r.SetSalt(header)
}
// readUp fills b completely from the underlying reader and returns the number
// of bytes read.
//
// It follows io.ReadFull semantics: a read that fills b succeeds even if the
// underlying reader reported an error together with the last bytes; io.EOF is
// returned only when no bytes at all were read, and io.ErrUnexpectedEOF when
// the stream ended part-way through b.
func (r *Reader) readUp(b []byte) (n int, err error) {
	return io.ReadFull(r.reader, b)
}
// Read decrypts the stream segment by segment and copies plaintext into b.
// It keeps reading segments until b is full or the stream ends, so it may
// block on the underlying reader.
//
// n is always the number of plaintext bytes placed in b, including when an
// error is returned. Errors:
//   - ErrInvalidReadFlow if ReadMagic has not been called;
//   - ErrInvalidKeyState if GenerateKey has not been called;
//   - io.EOF once the stream ended cleanly on a segment boundary and no more
//     plaintext is available;
//   - io.ErrUnexpectedEOF if the stream ends inside a segment header or body;
//   - ErrInvalidSegmentLength if a segment declares more than the configured
//     segment size;
//   - a wrapped AEAD error if a segment fails authentication.
//
// The segment header is kept in a local buffer so that independent Readers
// never share mutable state.
func (r *Reader) Read(b []byte) (n int, err error) {
	if r.cmType == Uninitialised {
		return 0, ErrInvalidReadFlow
	}
	if r.aead == nil {
		return 0, ErrInvalidKeyState
	}
	if r.isEOF {
		return 0, io.EOF
	}

	nonceSize := r.aead.NonceSize()
	overhead := r.aead.Overhead()
	if want := int(r.segmentSize) + overhead + nonceSize; cap(r.buf) != want {
		r.buf = make([]byte, want)
	}

	var header [segmentSizeBytesLen]byte
	for n < len(b) {
		// Serve already-decrypted bytes first.
		if len(r.bufSlice) > 0 {
			c := copy(b[n:], r.bufSlice)
			r.bufSlice = r.bufSlice[c:]
			n += c
			continue
		}

		// Segment header. io.ReadFull reports io.EOF only when not a single
		// byte was read, i.e. the stream ended exactly on a segment boundary.
		if _, err = io.ReadFull(r.reader, header[:]); err != nil {
			if err == io.EOF {
				r.isEOF = true
				if n > 0 {
					return n, nil
				}
				return 0, io.EOF
			}
			return n, err
		}

		segmentLen := binary.LittleEndian.Uint32(header[:])
		if segmentLen > r.segmentSize {
			return n, ErrInvalidSegmentLength
		}

		// Segment body: nonce | ciphertext + tag.
		sealedLen := int(segmentLen) + overhead
		if _, err = io.ReadFull(r.reader, r.buf[:nonceSize+sealedLen]); err != nil {
			if err == io.EOF {
				// A header without a body is a truncated stream, not a clean end.
				err = io.ErrUnexpectedEOF
			}
			return n, err
		}

		nonce := r.buf[:nonceSize]
		ciphertext := r.buf[nonceSize : nonceSize+sealedLen]
		// Decrypt in place; the header doubles as the additional data.
		plaintext, oerr := r.aead.Open(ciphertext[:0], nonce, ciphertext, header[:])
		if oerr != nil {
			return n, fmt.Errorf("decrypt Read: %w", oerr)
		}
		r.bufSlice = plaintext
	}
	return n, nil
}