Skip to content

Commit

Permalink
Merge pull request #20 from jhoward-lm/transactional-client
Browse files Browse the repository at this point in the history
fix: use transactional client
  • Loading branch information
cpanato authored Jun 27, 2024
2 parents b39139e + 92b037e commit 36526f0
Showing 1 changed file with 57 additions and 23 deletions.
80 changes: 57 additions & 23 deletions backends/ent/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,31 @@ func (backend *Backend) Store(doc *sbom.Document, opts *storage.StoreOptions) er
return fmt.Errorf("%w", errInvalidEntOptions)
}

err := backend.client.Document.Create().
tx, err := backend.client.Tx(backend.ctx)
if err != nil {
return fmt.Errorf("creating transactional client: %w", err)
}

backend.ctx = ent.NewTxContext(backend.ctx, tx)

if err := tx.Document.Create().
SetID(doc.Metadata.Id).
OnConflict().
Ignore().
Exec(backend.ctx)
if err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("ent.Document: %w", err)
Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) {
return rollback(tx, fmt.Errorf("ent.Document: %w", err))
}

if err := backend.saveMetadata(doc.Metadata); err != nil {
return err
return rollback(tx, err)
}

if err := backend.saveNodeList(doc.NodeList); err != nil {
return err
return rollback(tx, err)
}

if err := tx.Commit(); err != nil {
return rollback(tx, err)
}

return nil
Expand All @@ -71,10 +81,12 @@ func (backend *Backend) saveDocumentTypes(docTypes []*sbom.DocumentType) error {
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)

for _, dt := range docTypes {
typeName := documenttype.Type(dt.Type.String())

newDocType := backend.client.DocumentType.Create().
newDocType := tx.DocumentType.Create().
SetNillableType(&typeName).
SetNillableName(dt.Name).
SetNillableDescription(dt.Description)
Expand All @@ -97,9 +109,11 @@ func (backend *Backend) saveEdges(edges []*sbom.Edge) error {
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)

for _, edge := range edges {
for _, toID := range edge.To {
newEdgeType := backend.client.EdgeType.Create().
newEdgeType := tx.EdgeType.Create().
SetType(edgetype.Type(edge.Type.String())).
SetFromID(edge.From).
SetToID(toID)
Expand All @@ -119,8 +133,10 @@ func (backend *Backend) saveExternalReferences(refs []*sbom.ExternalReference) e
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)

for _, ref := range refs {
newRef := backend.client.ExternalReference.Create().
newRef := tx.ExternalReference.Create().
SetURL(ref.Url).
SetComment(ref.Comment).
SetAuthority(ref.Authority).
Expand Down Expand Up @@ -150,12 +166,13 @@ func (backend *Backend) saveHashesEntries(hashes map[int32]string) error {
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)
entries := []*ent.HashesEntryCreate{}

for alg, content := range hashes {
algName := sbom.HashAlgorithm(alg).String()

entry := backend.client.HashesEntry.Create().
entry := tx.HashesEntry.Create().
SetHashAlgorithmType(hashesentry.HashAlgorithmType(algName)).
SetHashData(content)

Expand All @@ -170,7 +187,7 @@ func (backend *Backend) saveHashesEntries(hashes map[int32]string) error {
entries = append(entries, entry)
}

if err := backend.client.HashesEntry.CreateBulk(entries...).
if err := tx.HashesEntry.CreateBulk(entries...).
Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("ent.HashesEntry: %w", err)
}
Expand All @@ -183,12 +200,13 @@ func (backend *Backend) saveIdentifiersEntries(idents map[int32]string) error {
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)
entries := []*ent.IdentifiersEntryCreate{}

for typ, value := range idents {
typeName := sbom.SoftwareIdentifierType(typ).String()

entry := backend.client.IdentifiersEntry.Create().
entry := tx.IdentifiersEntry.Create().
SetSoftwareIdentifierType(identifiersentry.SoftwareIdentifierType(typeName)).
SetSoftwareIdentifierValue(value)

Expand All @@ -199,7 +217,7 @@ func (backend *Backend) saveIdentifiersEntries(idents map[int32]string) error {
entries = append(entries, entry)
}

if err := backend.client.IdentifiersEntry.CreateBulk(entries...).
if err := tx.IdentifiersEntry.CreateBulk(entries...).
Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("ent.IdentifiersEntry: %w", err)
}
Expand All @@ -212,7 +230,9 @@ func (backend *Backend) saveMetadata(md *sbom.Metadata) error {
return fmt.Errorf("%w", errUninitializedClient)
}

newMetadata := backend.client.Metadata.Create().
tx := ent.TxFromContext(backend.ctx)

newMetadata := tx.Metadata.Create().
SetID(md.Id).
SetDocumentID(md.Id).
SetVersion(md.Version).
Expand Down Expand Up @@ -247,7 +267,8 @@ func (backend *Backend) saveNodeList(nodeList *sbom.NodeList) error {
return fmt.Errorf("%w", errUninitializedClient)
}

newNodeList := backend.client.NodeList.Create().
tx := ent.TxFromContext(backend.ctx)
newNodeList := tx.NodeList.Create().
SetRootElements(nodeList.RootElements)

if documentID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok {
Expand Down Expand Up @@ -281,8 +302,7 @@ func (backend *Backend) saveNodes(nodes []*sbom.Node) error { //nolint:cyclop
for _, n := range nodes {
newNode := backend.newNodeCreate(n)

err := newNode.OnConflict().Ignore().Exec(backend.ctx)
if err != nil && !ent.IsConstraintError(err) {
if err := newNode.OnConflict().Ignore().Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) {
return fmt.Errorf("ent.Node: %w", err)
}

Expand Down Expand Up @@ -321,8 +341,10 @@ func (backend *Backend) savePersons(persons []*sbom.Person) error {
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)

for _, p := range persons {
newPerson := backend.client.Person.Create().
newPerson := tx.Person.Create().
SetName(p.Name).
SetEmail(p.Email).
SetIsOrg(p.IsOrg).
Expand Down Expand Up @@ -357,10 +379,11 @@ func (backend *Backend) savePurposes(purposes []sbom.Purpose) error {
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)
builders := []*ent.PurposeCreate{}

for idx := range purposes {
newPurpose := backend.client.Purpose.Create().
newPurpose := tx.Purpose.Create().
SetPrimaryPurpose(purpose.PrimaryPurpose(purposes[idx].String()))

if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok {
Expand All @@ -370,7 +393,7 @@ func (backend *Backend) savePurposes(purposes []sbom.Purpose) error {
builders = append(builders, newPurpose)
}

err := backend.client.Purpose.CreateBulk(builders...).
err := tx.Purpose.CreateBulk(builders...).
OnConflict().
Ignore().
Exec(backend.ctx)
Expand All @@ -386,10 +409,11 @@ func (backend *Backend) saveTools(tools []*sbom.Tool) error {
return fmt.Errorf("%w", errUninitializedClient)
}

tx := ent.TxFromContext(backend.ctx)
builders := []*ent.ToolCreate{}

for _, t := range tools {
newTool := backend.client.Tool.Create().
newTool := tx.Tool.Create().
SetName(t.Name).
SetVersion(t.Version).
SetVendor(t.Vendor)
Expand All @@ -401,7 +425,7 @@ func (backend *Backend) saveTools(tools []*sbom.Tool) error {
builders = append(builders, newTool)
}

err := backend.client.Tool.CreateBulk(builders...).
err := tx.Tool.CreateBulk(builders...).
OnConflict().
Ignore().
Exec(backend.ctx)
Expand All @@ -413,7 +437,9 @@ func (backend *Backend) saveTools(tools []*sbom.Tool) error {
}

func (backend *Backend) newNodeCreate(n *sbom.Node) *ent.NodeCreate {
newNode := backend.client.Node.Create().
tx := ent.TxFromContext(backend.ctx)

newNode := tx.Node.Create().
SetID(n.Id).
SetAttribution(n.Attribution).
SetBuildDate(n.BuildDate.AsTime()).
Expand Down Expand Up @@ -441,3 +467,11 @@ func (backend *Backend) newNodeCreate(n *sbom.Node) *ent.NodeCreate {

return newNode
}

func rollback(tx *ent.Tx, err error) error {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("rolling back transaction: %w", rollbackErr)
}

return err
}

0 comments on commit 36526f0

Please sign in to comment.