Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support generating code for a type graph #27

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type User = {

`zodToTs()` and `createTypeAlias()` return a TS AST nodes, so if you want to get the node as a string, you can use the `printNode()` utility.

`zodToTs()`:
### `zodToTs()`

```ts
import { printNode, zodToTs } from 'zod-to-ts'
Expand All @@ -93,7 +93,7 @@ result:
}"
```

`createTypeAlias()`:
### `createTypeAlias()`

```ts
import { createTypeAlias, printNode, zodToTs } from 'zod-to-ts'
Expand All @@ -118,6 +118,42 @@ result:
}"
```

### `zodToTsMultiple`

```ts
const address = z.object({
addressLine1: z.string(),
addressLine2: z.string()
})

const customer = z.object({
name: z.string(),
age: z.number(),
addresses: z.array(address),
})

const zodtoTsResult = zodToTsMultiple({
Customer: customer,
Address: address,
})

const tsSourceText = zodtoTsResult.typeAliases.map(ta => printNode(ta)).join("\n")
```

result:

```ts
type Customer = {
name: string;
age: number;
addresses: Address[];
};
type Address = {
addressLine1: string;
addressLine2: string;
};
```

## Overriding Types

You can use `withGetType` to override a type, which is useful when more information is needed to determine the actual type. Unfortunately, this means working with the TS AST:
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"typescript": "4.9.4",
"vite": "4.0.3",
"vitest": "0.26.2",
"zod": "3.20.2"
"zod": "3.20.6"
},
"sideEffects": false,
"tsup": {
Expand Down
8 changes: 4 additions & 4 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

159 changes: 157 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import ts from 'typescript'
import { ZodTypeAny } from 'zod'
import ts, {NodeArray, TypeAliasDeclaration} from 'typescript'
import {
AnyZodObject,
z,
ZodArray,
ZodDiscriminatedUnion,
ZodEffects, ZodFunction, ZodIntersection, ZodMap,
ZodRecord, ZodSet,
ZodTuple,
ZodTypeAny,
ZodUnion
} from 'zod'
import {
GetType,
GetTypeFunction,
Expand All @@ -18,6 +28,8 @@ import {
maybeIdentifierToTypeReference,
printNode,
} from './utils'
import {withGetType} from "./utils";
import {AnyZodTuple, ZodUnionOptions} from "zod/lib/types";

const { factory: f } = ts

Expand Down Expand Up @@ -364,6 +376,149 @@ const zodToTsNode = (
return f.createKeywordTypeNode(ts.SyntaxKind.AnyKeyword)
}


const recursiveWithGetType = (zod: z.ZodTypeAny, identifiersByType: Map<z.ZodTypeAny, ts.Identifier>) => {

if (identifiersByType.has(zod)) {
withGetType(zod, () => identifiersByType.get(zod)!)
}

const {typeName} = zod._def;
switch (typeName) {
case 'ZodEnum':
case 'ZodNativeEnum':
case 'ZodString':
case 'ZodNumber':
case 'ZodBigInt':
case 'ZodBoolean':
case 'ZodDate':
case 'ZodUndefined':
case 'ZodNull':
case 'ZodVoid':
case 'ZodAny':
case 'ZodUnknown':
case 'ZodNever':
case 'ZodLiteral': {
// Nothing to do for above
break;
}
case 'ZodLazy': {
break
}
case 'ZodArray': {
recursiveWithGetType((zod as ZodArray<ZodTypeAny>).element, identifiersByType)
break;
}
case 'ZodUnion':{
const zdu = zod as ZodUnion<ZodUnionOptions>
for (const option of zdu.options) {
recursiveWithGetType(option, identifiersByType)
}
break;
}
case 'ZodDiscriminatedUnion': {
const zdu = zod as ZodDiscriminatedUnion<string, AnyZodObject[]>
for (const option of zdu.options) {
recursiveWithGetType(option, identifiersByType)
}
break;
}
case 'ZodEffects': {
recursiveWithGetType((zod as ZodEffects<ZodTypeAny>).innerType(), identifiersByType)
break;
}
case 'ZodTuple': {
const zodTuple = zod as ZodTuple
for (const item of zodTuple.items) {
recursiveWithGetType(item, identifiersByType)
}
break;
}
case 'ZodRecord': {
const zodRecord = zod as ZodRecord
recursiveWithGetType(zodRecord.keySchema, identifiersByType)
recursiveWithGetType(zodRecord.valueSchema, identifiersByType)
break
}
case 'ZodMap':{
const zodMap = zod as ZodMap
recursiveWithGetType(zodMap._def.keyType, identifiersByType)
recursiveWithGetType(zodMap._def.valueType, identifiersByType)
break
}
case 'ZodSet': {
recursiveWithGetType((zod as ZodSet)._def.valueType, identifiersByType)
break
}
case 'ZodIntersection': {
const intersection = zod as ZodIntersection<ZodTypeAny, ZodTypeAny>
recursiveWithGetType(intersection._def.left, identifiersByType)
recursiveWithGetType(intersection._def.right, identifiersByType)
break
}
case 'ZodBranded':
case 'ZodNullable':
case 'ZodOptional':
case 'ZodDefault':
case 'ZodPromise': {
if ('unwrap' in zod && typeof zod.unwrap === 'function') {
recursiveWithGetType(zod.unwrap(), identifiersByType)
}
break;
}
case 'ZodFunction': {
const zodFunction = (zod as ZodFunction<AnyZodTuple, ZodTypeAny>)
recursiveWithGetType(zodFunction.parameters(), identifiersByType)
recursiveWithGetType(zodFunction.returnType(), identifiersByType)
break;
}
case 'ZodObject': {
const properties = Object.entries((zod as AnyZodObject).shape)
for (const [_key, value] of properties) {
if (value instanceof z.ZodType) {
recursiveWithGetType(value, identifiersByType)
}
}
break;
}
}
}

export const zodToTsMultiple = (objectGraph: {[identifier: string]: z.ZodTypeAny}, options?: ZodToTsOptions): {
typeAliases: NodeArray<TypeAliasDeclaration>
store: ZodToTsStore
} => {
const typesByIdentifier = new Map(Object.entries(objectGraph).map(([key, value]) => [f.createIdentifier(key), value]))
const identifiersByType = new Map<z.ZodTypeAny, ts.Identifier>([...typesByIdentifier.entries()].map(([key, value]) => ([value, key])))

// traverse object graph, find references, wrap them with `withGetType`
const zodTypes = [...typesByIdentifier.values()]
for (const zodType of zodTypes) {
recursiveWithGetType(zodType, identifiersByType)
}

const results = [...typesByIdentifier.entries()].map(([identifier, zodType]) => {
// Must not print alias for root level types, otherwise all we will print is aliases
// Clone the object and remove any getType function
const clone: ZodTypeAny & {getType?: GetTypeFunction} = Object.assign(Object.create(Object.getPrototypeOf(zodType)), zodType)
delete clone['getType']
const {node, store} = zodToTs(clone, undefined, options)
const typeAlias = f.createTypeAliasDeclaration(undefined, identifier, undefined, node)
return {typeAlias, store}
})


return {
typeAliases: f.createNodeArray(results.map(result => result.typeAlias)),
store: {
// eslint-disable-next-line unicorn/no-array-reduce
nativeEnums: results.map(result => result.store.nativeEnums).reduce((previous, current) => [...previous, ...current])
}
}
}



export { createTypeAlias, printNode }
export { withGetType } from './utils'
export type { GetType, ZodToTsOptions }
22 changes: 22 additions & 0 deletions test/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,25 @@ import ts from 'typescript'
import { printNode } from '../src'

export const printNodeTest = (node: ts.Node) => printNode(node, { newLine: ts.NewLineKind.LineFeed })

/**
* Removes indentation from inline multiline strings
* Useful for unit tests
*
* An alternative to Vi.JestAssertion.toMatchInlineSnapshot,
* because it adds annoying quotes to the snapshot text
*/
export const stripIndent = (indented: string): string => {
const lines = indented.split("\n")
// eslint-disable-next-line unicorn/no-array-reduce
const commonIndent = lines.reduce((accumulator, line) => {
if(/^\s*$/.test(line)) return accumulator
const whiteSpaceMatch = /^\s*/.exec(line)
const lineIndentLength = whiteSpaceMatch ? whiteSpaceMatch[0].length : 0;
return Math.min(accumulator, lineIndentLength)
}, Number.MAX_SAFE_INTEGER)

const withoutIndent = lines.map(line => line.slice(commonIndent)).join("\n")
// remove one leading and trailing newline, but no more than one
return withoutIndent.replace(/^\n/, "").replace(/\n$/,"")
}
Loading