Skip to content

Commit

Permalink
neo4j updates
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Dec 13, 2024
1 parent 626b130 commit 1949bd9
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 10 deletions.
60 changes: 50 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
package assetdb

import (
"context"
"embed"
"fmt"
"math/rand"
"net/url"
"strings"
"time"

"github.com/glebarez/sqlite"
neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
neomigrations "github.com/owasp-amass/asset-db/migrations/neo4j"
pgmigrations "github.com/owasp-amass/asset-db/migrations/postgres"
sqlitemigrations "github.com/owasp-amass/asset-db/migrations/sqlite3"
"github.com/owasp-amass/asset-db/repository"

Check failure on line 22 in db.go

View workflow job for this annotation

GitHub Actions / lint

could not import github.com/owasp-amass/asset-db/repository (-: # github.com/owasp-amass/asset-db/repository
Expand Down Expand Up @@ -37,23 +44,20 @@ func New(dbtype, dsn string) (repository.Repository, error) {
}

func migrateDatabase(dbtype, dsn string) error {
var name string
var fs embed.FS
var database gorm.Dialector

switch dbtype {
case sqlrepo.SQLite:
fallthrough
case sqlrepo.SQLiteMemory:
name = "sqlite3"
fs = sqlitemigrations.Migrations()
database = sqlite.Open(dsn)
return sqlMigrate("sqlite3", sqlite.Open(dsn), sqlitemigrations.Migrations())
case sqlrepo.Postgres:
name = "postgres"
fs = pgmigrations.Migrations()
database = postgres.Open(dsn)
return sqlMigrate("postgres", postgres.Open(dsn), pgmigrations.Migrations())
case neo4jdb.Neo4j:

Check failure on line 54 in db.go

View workflow job for this annotation

GitHub Actions / lint

undefined: neo4jdb.Neo4j (typecheck)
return neoMigrate(dsn)
}
return nil
}

func sqlMigrate(name string, database gorm.Dialector, fs embed.FS) error {
sql, err := gorm.Open(database, &gorm.Config{})
if err != nil {
return err
Expand All @@ -76,3 +80,39 @@ func migrateDatabase(dbtype, dsn string) error {
}
return nil
}

func neoMigrate(dsn string) error {
u, err := url.Parse(dsn)
if err != nil {
return err
}

auth := neo4jdb.NoAuth()
var username, password string
if u.User != nil {
username = u.User.Username()
password, _ = u.User.Password()
auth = neo4jdb.BasicAuth(username, password, "")
}
dbname := strings.TrimPrefix(u.Path, "/")

newdsn := fmt.Sprintf("%s://%s", u.Scheme, u.Host)
driver, err := neo4jdb.NewDriverWithContext(newdsn, auth, func(cfg *config.Config) {
cfg.MaxConnectionPoolSize = 20
cfg.MaxConnectionLifetime = time.Hour
cfg.ConnectionLivenessCheckTimeout = 10 * time.Minute
})
if err != nil {
return err
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if err := driver.VerifyConnectivity(ctx); err != nil {
return err
}
defer driver.Close(context.Background())

return neomigrations.InitializeSchema(driver, dbname)
}
137 changes: 137 additions & 0 deletions migrations/neo4j/schema.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright © by Jeff Foley 2017-2024. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// SPDX-License-Identifier: Apache-2.0

package neo4j

import (
"context"

neo4jdb "github.com/neo4j/neo4j-go-driver/v5/neo4j"
)

func InitializeSchema(driver neo4jdb.DriverWithContext, dbname string) error {
_ = executeQuery(driver, dbname, "CREATE DATABASE "+dbname+" IF NOT EXISTS")
_ = executeQuery(driver, dbname, "START DATABASE "+dbname+" WAIT 10 SECONDS")

err := executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_entities_entity_id IF NOT EXISTS FOR (n:Entity) REQUIRE n.entity_id IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE INDEX entities_range_index_updated_at IF NOT EXISTS FOR (n:Entity) ON (n.updated_at)")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_enttag_tag_id IF NOT EXISTS FOR (n:EntityTag) REQUIRE n.tag_id IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE INDEX enttag_range_index_updated_at IF NOT EXISTS FOR (n:EntityTag) ON (n.updated_at)")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE INDEX enttag_range_index_entity_id IF NOT EXISTS FOR (n:EntityTag) ON (n.entity_id)")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_edgetag_tag_id IF NOT EXISTS FOR (n:EdgeTag) REQUIRE n.tag_id IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE INDEX edgetag_range_index_updated_at IF NOT EXISTS FOR (n:EdgeTag) ON (n.updated_at)")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE INDEX edgetag_range_index_edge_id IF NOT EXISTS FOR (n:EdgeTag) ON (n.edge_id)")
if err != nil {
return err
}

return entitiesContentIndexes(driver, dbname)
}

func entitiesContentIndexes(driver neo4jdb.DriverWithContext, dbname string) error {
err := executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_autnum_content_handle IF NOT EXISTS FOR (n:AutnumRecord) REQUIRE n.handle IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_autnum_content_number IF NOT EXISTS FOR (n:AutnumRecord) REQUIRE n.number IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_autsys_content_number IF NOT EXISTS FOR (n:AutonomousSystem) REQUIRE n.number IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_domainrec_content_domain IF NOT EXISTS FOR (n:DomainRecord) REQUIRE n.domain IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_email_content_address IF NOT EXISTS FOR (n:EmailAddress) REQUIRE n.address IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_fqdn_content_name IF NOT EXISTS FOR (n:FQDN) REQUIRE n.name IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_ipaddr_content_address IF NOT EXISTS FOR (n:IPAddress) REQUIRE n.address IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_ipnetrec_content_cidr IF NOT EXISTS FOR (n:IPNetRecord) REQUIRE n.cidr IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_ipnetrec_content_handle IF NOT EXISTS FOR (n:IPNetRecord) REQUIRE n.handle IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_netblock_content_cidr IF NOT EXISTS FOR (n:Netblock) REQUIRE n.cidr IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_org_content_name IF NOT EXISTS FOR (n:Organization) REQUIRE n.name IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE INDEX person_range_index_full_name IF NOT EXISTS FOR (n:Person) ON (n.full_name)")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_tls_content_serial_number IF NOT EXISTS FOR (n:TLSCertificate) REQUIRE n.serial_number IS UNIQUE")
if err != nil {
return err
}

err = executeQuery(driver, dbname, "CREATE CONSTRAINT constraint_url_content_url IF NOT EXISTS FOR (n:URL) REQUIRE n.url IS UNIQUE")
if err != nil {
return err
}
return nil
}

func executeQuery(driver neo4jdb.DriverWithContext, dbname, query string) error {
_, err := neo4jdb.ExecuteQuery(context.Background(), driver,
query, nil, neo4jdb.EagerResultTransformer, neo4jdb.ExecuteQueryWithDatabase(dbname))
return err
}
34 changes: 34 additions & 0 deletions repository/neo4j/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//go:build integration

// Copyright © by Jeff Foley 2017-2024. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// SPDX-License-Identifier: Apache-2.0

package neo4j

import (
"fmt"
"os"
"testing"

neomigrations "github.com/owasp-amass/asset-db/migrations/neo4j"
)

var store *neoRepository

func TestMain(m *testing.M) {
dsn := "bolt://neo4j:hackme4fun@localhost:7687/assetdb"

store, err := New("neo4j", dsn)
if err != nil {
fmt.Println(err)
return
}

if err := neomigrations.InitializeSchema(store.db, store.dbname); err != nil {
fmt.Println(err)
return
}

os.Exit(m.Run())
}

0 comments on commit 1949bd9

Please sign in to comment.