From c68f5c09eef09da719d7aa9d18ed1413a696b23f Mon Sep 17 00:00:00 2001
From: Quan Anh Mai
Date: Fri, 20 Dec 2024 17:31:39 +0700
Subject: [PATCH] Implement Vector API rearrange operation

---
 .../compiler/asm/amd64/AMD64Assembler.java    |  10 +
 .../lir/aarch64/AArch64PermuteOp.java         |  61 +++++
 .../lir/amd64/vector/AMD64VectorShuffle.java  | 252 ++++++++++++++++++
 3 files changed, 323 insertions(+)

diff --git a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/asm/amd64/AMD64Assembler.java b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/asm/amd64/AMD64Assembler.java
index d2e6ce7a250a..76525f0cb3a2 100644
--- a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/asm/amd64/AMD64Assembler.java
+++ b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/asm/amd64/AMD64Assembler.java
@@ -2327,7 +2327,11 @@ public static class VexRVMOp extends VexOp {
         public static final VexRVMOp VSQRTSD = new VexRVMOp("VSQRTSD", VEXPrefixConfig.P_F2, VEXPrefixConfig.M_0F, VEXPrefixConfig.WIG, 0x51, VEXOpAssertion.AVX1_AVX512F_128, EVEXTuple.T1S_64BIT, VEXPrefixConfig.W1);
         public static final VexRVMOp VSQRTSS = new VexRVMOp("VSQRTSS", VEXPrefixConfig.P_F3, VEXPrefixConfig.M_0F, VEXPrefixConfig.WIG, 0x51, VEXOpAssertion.AVX1_AVX512F_128, EVEXTuple.T1S_32BIT, VEXPrefixConfig.W0);
+        public static final VexRVMOp VPERMILPS = new VexRVMOp("VPERMILPS", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x0C, VEXOpAssertion.AVX1_AVX512F_VL, EVEXTuple.FVM, VEXPrefixConfig.W0);
         public static final VexRVMOp VPERMD = new VexRVMOp("VPERMD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x36, VEXOpAssertion.AVX2_AVX512F_VL_256_512, EVEXTuple.FVM, VEXPrefixConfig.W0);
+        public static final VexRVMOp VPERMPS = new VexRVMOp("VPERMPS", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x16, VEXOpAssertion.AVX2_AVX512F_VL_256_512, EVEXTuple.FVM, VEXPrefixConfig.W0);
+        public static final VexRVMOp VPERMILPD = new VexRVMOp("VPERMILPD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x0D, VEXOpAssertion.AVX1_AVX512F_VL, EVEXTuple.FVM, VEXPrefixConfig.W1);
+        public static final VexRVMOp VMOVSS = new VexRVMOp("VMOVSS", VEXPrefixConfig.P_F3, VEXPrefixConfig.M_0F, VEXPrefixConfig.WIG, 0x10, VEXOpAssertion.AVX1_AVX512F_128, EVEXTuple.T1S_32BIT, VEXPrefixConfig.W0);
         public static final VexRVMOp VMOVSD = new VexRVMOp("VMOVSD", VEXPrefixConfig.P_F2, VEXPrefixConfig.M_0F, VEXPrefixConfig.WIG, 0x10, VEXOpAssertion.AVX1_AVX512F_128, EVEXTuple.T1S_64BIT, VEXPrefixConfig.W1);
         public static final VexRVMOp VMOVHPD = new VexRVMOp("VMOVHPD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F, VEXPrefixConfig.WIG, 0x16, VEXOpAssertion.AVX1_AVX512F_128, EVEXTuple.T1S_64BIT, VEXPrefixConfig.W1);
@@ -2431,8 +2435,14 @@ public static class VexRVMOp extends VexOp {
         public static final VexRVMOp EVSQRTSD = new VexRVMOp("EVSQRTSD", VSQRTSD);
         public static final VexRVMOp EVSQRTSS = new VexRVMOp("EVSQRTSS", VSQRTSS);
+        public static final VexRVMOp EVPERMB = new VexRVMOp("EVPERMB", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x8D, VEXOpAssertion.AVX512_VBMI_VL, EVEXTuple.FVM, VEXPrefixConfig.W0, true);
         public static final VexRVMOp EVPERMW = new VexRVMOp("EVPERMW", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x8D, VEXOpAssertion.AVX512BW_VL, EVEXTuple.FVM, VEXPrefixConfig.W1, true);
+        public static final VexRVMOp EVPERMILPS = new VexRVMOp("EVPERMILPS", VPERMILPS);
         public static final VexRVMOp EVPERMD = new VexRVMOp("EVPERMD", VPERMD);
+        public static final VexRVMOp EVPERMPS = new VexRVMOp("EVPERMPS", VPERMPS);
+        public static final VexRVMOp EVPERMILPD = new VexRVMOp("EVPERMILPD", VPERMILPD);
+        public static final VexRVMOp EVPERMQ = new VexRVMOp("EVPERMQ", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x36, VEXOpAssertion.AVX512F_VL_256_512, EVEXTuple.FVM, VEXPrefixConfig.W1, true);
+        public static final VexRVMOp EVPERMPD = new VexRVMOp("EVPERMPD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x16, VEXOpAssertion.AVX512F_VL_256_512, EVEXTuple.FVM, VEXPrefixConfig.W1, true);
         public static final VexRVMOp EVPBLENDMB = new VexRVMOp("EVPBLENDMB", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x66, VEXOpAssertion.AVX512BW_VL, EVEXTuple.FVM, VEXPrefixConfig.W0, true);
         public static final VexRVMOp EVPBLENDMW = new VexRVMOp("EVPBLENDMW", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x66, VEXOpAssertion.AVX512BW_VL, EVEXTuple.FVM, VEXPrefixConfig.W1, true);
diff --git a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/aarch64/AArch64PermuteOp.java b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/aarch64/AArch64PermuteOp.java
index 4075611d38be..47f25b8ec206 100644
--- a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/aarch64/AArch64PermuteOp.java
+++ b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/aarch64/AArch64PermuteOp.java
@@ -30,13 +30,17 @@
 import jdk.graal.compiler.asm.aarch64.AArch64ASIMDAssembler.ASIMDSize;
 import jdk.graal.compiler.asm.aarch64.AArch64ASIMDAssembler.ElementSize;
 import jdk.graal.compiler.asm.aarch64.AArch64MacroAssembler;
+import jdk.graal.compiler.core.common.LIRKind;
 import jdk.graal.compiler.debug.GraalError;
 import jdk.graal.compiler.lir.asm.CompilationResultBuilder;
 import jdk.graal.compiler.lir.LIRInstructionClass;
 import jdk.graal.compiler.lir.Opcode;
+import jdk.graal.compiler.lir.gen.LIRGeneratorTool;
+import jdk.vm.ci.aarch64.AArch64Kind;
 import jdk.vm.ci.code.Register;
 import jdk.vm.ci.meta.AllocatableValue;
+import jdk.vm.ci.meta.Value;

 /**
  * This enum encapsulates AArch64 instructions which perform permutations.
@@ -102,4 +106,61 @@ public void emitCode(CompilationResultBuilder crb, AArch64MacroAssembler masm) {
         }
     }
+
+    public static class ASIMDPermuteOp extends AArch64LIRInstruction {
+        private static final LIRInstructionClass<ASIMDPermuteOp> TYPE = LIRInstructionClass.create(ASIMDPermuteOp.class);
+
+        @Def protected AllocatableValue result;
+        @Alive protected AllocatableValue source;
+        @Use protected AllocatableValue indices;
+        @Temp({OperandFlag.REG, OperandFlag.ILLEGAL}) protected AllocatableValue xtmp1;
+        @Temp({OperandFlag.REG, OperandFlag.ILLEGAL}) protected AllocatableValue xtmp2;
+
+        public ASIMDPermuteOp(LIRGeneratorTool tool, AllocatableValue result, AllocatableValue source, AllocatableValue indices) {
+            super(TYPE);
+            this.result = result;
+            this.source = source;
+            this.indices = indices;
+            AArch64Kind eKind = ((AArch64Kind) result.getPlatformKind()).getScalar();
+            this.xtmp1 = eKind == AArch64Kind.BYTE ? Value.ILLEGAL : tool.newVariable(LIRKind.value(AArch64Kind.V128_BYTE));
+            this.xtmp2 = eKind == AArch64Kind.BYTE ? Value.ILLEGAL : tool.newVariable(LIRKind.value(AArch64Kind.V128_BYTE));
+        }
+
+        @Override
+        public void emitCode(CompilationResultBuilder crb, AArch64MacroAssembler masm) {
+            AArch64Kind vKind = (AArch64Kind) result.getPlatformKind();
+            AArch64Kind eKind = vKind.getScalar();
+            ASIMDSize vSize = ASIMDSize.fromVectorKind(vKind);
+            Register xtmp1Reg = xtmp1.equals(Value.ILLEGAL) ? Register.None : asRegister(xtmp1);
+            Register xtmp2Reg = xtmp2.equals(Value.ILLEGAL) ? Register.None : asRegister(xtmp2);
+            Register currentIdxReg = asRegister(indices);
+            // Since NEON only supports byte lookups, we repeatedly convert a 2W-bit lookup into a
+            // W-bit lookup by transforming each 2W-bit index with value v into a pair of W-bit
+            // indices v * 2, v * 2 + 1, until the element width reaches Byte.SIZE
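+            // For example, the 32-bit indices <1, 0, 3, 2> become the byte indices
+            // <4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11> that are fed to TBL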
+            if (eKind.getSizeInBytes() == AArch64Kind.QWORD.getSizeInBytes()) {
+                masm.neon.shlVVI(vSize, ElementSize.DoubleWord, xtmp1Reg, currentIdxReg, 1);
+                masm.neon.shlVVI(vSize, ElementSize.DoubleWord, xtmp2Reg, xtmp1Reg, Integer.SIZE);
+                masm.neon.orrVVV(vSize, xtmp1Reg, xtmp1Reg, xtmp2Reg);
+                masm.neon.orrVI(vSize, ElementSize.DoubleWord, xtmp1Reg, 1L << Integer.SIZE);
+                currentIdxReg = xtmp1Reg;
+                eKind = AArch64Kind.DWORD;
+            }
+            if (eKind.getSizeInBytes() == AArch64Kind.DWORD.getSizeInBytes()) {
+                masm.neon.shlVVI(vSize, ElementSize.Word, xtmp1Reg, currentIdxReg, 1);
+                masm.neon.shlVVI(vSize, ElementSize.Word, xtmp2Reg, xtmp1Reg, Short.SIZE);
+                masm.neon.orrVVV(vSize, xtmp1Reg, xtmp1Reg, xtmp2Reg);
+                masm.neon.orrVI(vSize, ElementSize.Word, xtmp1Reg, 1 << Short.SIZE);
+                currentIdxReg = xtmp1Reg;
+                eKind = AArch64Kind.WORD;
+            }
+            if (eKind.getSizeInBytes() == AArch64Kind.WORD.getSizeInBytes()) {
+                masm.neon.shlVVI(vSize, ElementSize.HalfWord, xtmp1Reg, currentIdxReg, 1);
+                masm.neon.shlVVI(vSize, ElementSize.HalfWord, xtmp2Reg, xtmp1Reg, Byte.SIZE);
+                masm.neon.orrVVV(vSize, xtmp1Reg, xtmp1Reg, xtmp2Reg);
+                masm.neon.orrVI(vSize, ElementSize.HalfWord, xtmp1Reg, 1 << Byte.SIZE);
+                currentIdxReg = xtmp1Reg;
+            }
+            masm.neon.tblVVV(vSize, asRegister(result), asRegister(source), currentIdxReg);
+        }
+    }
 }
diff --git a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/vector/AMD64VectorShuffle.java b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/vector/AMD64VectorShuffle.java
index e5c98aad8132..adfe365cdbad 100644
--- a/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/vector/AMD64VectorShuffle.java
+++ b/compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/vector/AMD64VectorShuffle.java
@@ -63,6 +63,7 @@
 import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMOp.VMOVSD;
 import static jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMOp.VPSHUFB;
 import static jdk.graal.compiler.asm.amd64.AVXKind.AVXSize.XMM;
+import static jdk.graal.compiler.asm.amd64.AVXKind.AVXSize.YMM;
 import static jdk.graal.compiler.asm.amd64.AVXKind.AVXSize.ZMM;
 import static jdk.vm.ci.code.ValueUtil.asRegister;
 import static jdk.vm.ci.code.ValueUtil.isRegister;
@@ -71,13 +72,19 @@
 import jdk.graal.compiler.asm.amd64.AMD64Address;
 import jdk.graal.compiler.asm.amd64.AMD64Assembler;
 import jdk.graal.compiler.asm.amd64.AMD64Assembler.AMD64SIMDInstructionEncoding;
+import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexMoveMaskOp;
 import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexMRIOp;
 import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRMIOp;
+import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRMOp;
 import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMIOp;
+import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMOp;
+import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRVMROp;
+import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexShiftOp;
 import jdk.graal.compiler.asm.amd64.AMD64BaseAssembler;
 import jdk.graal.compiler.asm.amd64.AMD64MacroAssembler;
 import jdk.graal.compiler.asm.amd64.AVXKind;
 import jdk.graal.compiler.asm.amd64.AVXKind.AVXSize;
+import jdk.graal.compiler.core.amd64.AMD64LIRGenerator;
 import jdk.graal.compiler.core.common.LIRKind;
 import jdk.graal.compiler.debug.Assertions;
 import jdk.graal.compiler.debug.GraalError;
@@ -89,9 +96,254 @@
 import jdk.vm.ci.amd64.AMD64Kind;
 import jdk.vm.ci.code.Register;
 import jdk.vm.ci.meta.AllocatableValue;
+import jdk.vm.ci.meta.JavaConstant;
+import jdk.vm.ci.meta.Value;

 public class AMD64VectorShuffle {

+    /**
+     * General purpose permutation: this node looks up elements from a source vector using the index
+     * vector as the selector.
+     */
+    public static final class PermuteOp extends AMD64LIRInstruction {
+        public static final LIRInstructionClass<PermuteOp> TYPE = LIRInstructionClass.create(PermuteOp.class);
+
+        @Def protected AllocatableValue result;
+        @Use protected AllocatableValue source;
+        @Use protected AllocatableValue indices;
+        private final AMD64SIMDInstructionEncoding encoding;
+
+        private PermuteOp(AllocatableValue result, AllocatableValue source, AllocatableValue indices, AMD64SIMDInstructionEncoding encoding) {
+            super(TYPE);
+            this.result = result;
+            this.source = source;
+            this.indices = indices;
+            this.encoding = encoding;
+        }
+
+        public static AMD64LIRInstruction create(AMD64LIRGenerator gen, AllocatableValue result, AllocatableValue source, AllocatableValue indices, AMD64SIMDInstructionEncoding encoding) {
+            AMD64Kind eKind = ((AMD64Kind) result.getPlatformKind()).getScalar();
+            AVXSize avxSize = AVXKind.getRegisterSize(result);
+            return switch (eKind) {
+                case BYTE -> {
+                    if (gen.supportsCPUFeature(CPUFeature.AVX512_VBMI) || avxSize == XMM) {
+                        yield new PermuteOp(result, source, indices, encoding);
+                    } else {
+                        yield switch (avxSize) {
+                            case YMM -> new PermuteOpWithTemps(gen, result, source, indices, encoding, 3, false);
+                            case ZMM -> new PermuteOpWithTemps(gen, result, source, indices, encoding, 3, true);
+                            default -> throw GraalError.shouldNotReachHereUnexpectedValue(avxSize);
+                        };
+                    }
+                }
+                case WORD -> {
+                    if (encoding == AMD64SIMDInstructionEncoding.EVEX) {
+                        GraalError.guarantee(gen.supportsCPUFeature(CPUFeature.AVX512BW) && gen.supportsCPUFeature(CPUFeature.AVX512VL), "must support basic avx512");
+                        yield new PermuteOp(result, source, indices, encoding);
+                    } else {
+                        GraalError.guarantee(avxSize.getBytes() < ZMM.getBytes(), "zmm requires evex");
+                        yield switch (avxSize) {
+                            case XMM, YMM -> new PermuteOpWithTemps(gen, result, source, indices, encoding, 3, false);
+                            default -> throw GraalError.shouldNotReachHereUnexpectedValue(avxSize);
+                        };
+                    }
+                }
+                case DWORD, SINGLE -> new PermuteOp(result, source, indices, encoding);
+                case QWORD, DOUBLE -> {
+                    if (encoding == AMD64SIMDInstructionEncoding.EVEX || avxSize != YMM) {
+                        yield new PermuteOp(result, source, indices, encoding);
+                    } else {
+                        yield new PermuteOpWithTemps(gen, result, source, indices, encoding, 2, false);
+                    }
+                }
+                default -> throw GraalError.shouldNotReachHereUnexpectedValue(eKind);
+            };
+        }
+
+        @Override
+        public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
+            AMD64Kind eKind = ((AMD64Kind) result.getPlatformKind()).getScalar();
+            AVXSize avxSize = AVXKind.getRegisterSize(result);
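+            // Instruction selection: BYTE uses VPSHUFB (XMM) or EVPERMB; WORD uses EVPERMW;
+            // DWORD/SINGLE use VPERMILPS (XMM) or VPERMD/VPERMPS; QWORD/DOUBLE use VPERMILPD (XMM)
+            // or EVPERMQ/EVPERMPD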
+            switch (eKind) {
+                case BYTE -> {
+                    if (avxSize == XMM) {
+                        VexRVMOp.VPSHUFB.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(source), asRegister(indices));
+                    } else {
+                        VexRVMOp.EVPERMB.encoding(encoding).emit(masm, avxSize, asRegister(result), asRegister(indices), asRegister(source));
+                    }
+                }
+                case WORD -> VexRVMOp.EVPERMW.encoding(encoding).emit(masm, avxSize, asRegister(result), asRegister(indices), asRegister(source));
+                case DWORD, SINGLE -> {
+                    if (avxSize.getBytes() <= XMM.getBytes()) {
+                        VexRVMOp.VPERMILPS.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(source), asRegister(indices));
+                    } else if (((AMD64Kind) result.getPlatformKind()).getScalar().isInteger()) {
+                        VexRVMOp.VPERMD.encoding(encoding).emit(masm, avxSize, asRegister(result), asRegister(indices), asRegister(source));
+                    } else {
+                        VexRVMOp.VPERMPS.encoding(encoding).emit(masm, avxSize, asRegister(result), asRegister(indices), asRegister(source));
+                    }
+                }
+                case QWORD, DOUBLE -> {
+                    if (avxSize.getBytes() <= XMM.getBytes()) {
+                        VexRVMOp.VPERMILPD.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(source), asRegister(indices));
+                    } else if (((AMD64Kind) result.getPlatformKind()).getScalar().isInteger()) {
+                        VexRVMOp.EVPERMQ.encoding(encoding).emit(masm, avxSize, asRegister(result), asRegister(indices), asRegister(source));
+                    } else {
+                        VexRVMOp.EVPERMPD.encoding(encoding).emit(masm, avxSize, asRegister(result), asRegister(indices), asRegister(source));
+                    }
+                }
+                default -> throw GraalError.shouldNotReachHereUnexpectedValue(eKind);
+            }
+        }
+    }
+
+    /**
+     * Similar to {@code PermuteOp}, except that this node may use additional temp registers. It is
+     * split out so that the inputs of {@code PermuteOp} do not need to be
+     * {@link jdk.graal.compiler.lir.LIRInstruction.Alive}.
+     */
+    private static final class PermuteOpWithTemps extends AMD64LIRInstruction {
+        public static final LIRInstructionClass<PermuteOpWithTemps> TYPE = LIRInstructionClass.create(PermuteOpWithTemps.class);
+
+        @Def protected AllocatableValue result;
+        @Alive protected AllocatableValue source;
+        @Alive protected AllocatableValue indices;
+        @Temp protected AllocatableValue[] xtmps;
+        @Temp({OperandFlag.REG, OperandFlag.ILLEGAL}) protected AllocatableValue ktmp;
+        private final AMD64SIMDInstructionEncoding encoding;
+
+        private PermuteOpWithTemps(AMD64LIRGenerator gen, AllocatableValue result, AllocatableValue source, AllocatableValue indices, AMD64SIMDInstructionEncoding encoding, int xtmpRegs,
+                        boolean ktmpReg) {
+            super(TYPE);
+            GraalError.guarantee(xtmpRegs <= 3, "too many temporaries, %d", xtmpRegs);
+            this.result = result;
+            this.source = source;
+            this.indices = indices;
+            this.xtmps = new AllocatableValue[xtmpRegs];
+            for (int i = 0; i < xtmpRegs; i++) {
+                this.xtmps[i] = gen.newVariable(indices.getValueKind());
+            }
+            this.ktmp = ktmpReg ? gen.newVariable(LIRKind.value(AMD64Kind.MASK64)) : Value.ILLEGAL;
+            this.encoding = encoding;
+        }
+
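+        // Emulates permutes that have no directly matching instruction under the current encoding by
+        // rewriting the index vector and delegating to the permutes that are available, blending
+        // partial results where necessary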
+        @Override
+        public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
+            AMD64Kind eKind = ((AMD64Kind) result.getPlatformKind()).getScalar();
+            AVXSize avxSize = AVXKind.getRegisterSize(result);
+            switch (eKind) {
+                case BYTE -> {
+                    GraalError.guarantee(!masm.supports(CPUFeature.AVX512_VBMI) && avxSize.getBytes() > XMM.getBytes(), "should be a PermuteOp");
+                    emitBytePermute(crb, masm, asRegister(indices));
+                }
+                case WORD -> {
+                    GraalError.guarantee(!masm.supports(CPUFeature.AVX512BW) && avxSize != ZMM, "should be PermuteOp");
+                    Register indexReg = asRegister(indices);
+                    Register xtmp1Reg = asRegister(xtmps[0]);
+                    Register xtmp2Reg = asRegister(xtmps[1]);
+                    Register xtmp3Reg = asRegister(xtmps[2]);
+
+                    // Transform into a byte permute by transforming each 16-bit index with value v
+                    // into a pair of 8-bit indices v * 2, v * 2 + 1
+                    VexShiftOp.VPSLLW.encoding(encoding).emit(masm, avxSize, xtmp1Reg, indexReg, Byte.SIZE + 1);
+                    AMD64Address inc = (AMD64Address) crb.recordDataReferenceInCode(JavaConstant.forInt(0x01000100), Integer.BYTES);
+                    VexRMOp broadcastOp = masm.supports(CPUFeature.AVX2) ? VexRMOp.VPBROADCASTD : VexRMOp.VBROADCASTSS;
+                    broadcastOp.encoding(encoding).emit(masm, avxSize, xtmp2Reg, inc);
+                    VexRVMOp.VPOR.encoding(encoding).emit(masm, avxSize, xtmp1Reg, xtmp1Reg, xtmp2Reg);
+                    VexShiftOp.VPSLLW.encoding(encoding).emit(masm, avxSize, xtmp2Reg, indexReg, 1);
+                    VexRVMOp.VPOR.encoding(encoding).emit(masm, avxSize, xtmp3Reg, xtmp1Reg, xtmp2Reg);
+                    emitBytePermute(crb, masm, xtmp3Reg);
+                }
+                case DWORD, SINGLE -> throw GraalError.shouldNotReachHere("should be PermuteOp");
+                case QWORD, DOUBLE -> {
+                    GraalError.guarantee(encoding == AMD64SIMDInstructionEncoding.VEX && avxSize == YMM, "should be PermuteOp");
+                    Register indexReg = asRegister(indices);
+                    Register xtmp1Reg = asRegister(xtmps[0]);
+                    Register xtmp2Reg = asRegister(xtmps[1]);
+
+                    // Transform into an int permute by transforming each 64-bit index with value v
+                    // into a pair of 32-bit indices v * 2, v * 2 + 1
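+                    // For example, the long indices <1, 0, 3, 2> become the int indices
+                    // <2, 3, 0, 1, 6, 7, 4, 5> consumed by vpermd/vpermps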
+                    VexShiftOp.VPSLLQ.encoding(encoding).emit(masm, YMM, xtmp1Reg, indexReg, Integer.SIZE + 1);
+                    AMD64Address inc = (AMD64Address) crb.asLongConstRef(JavaConstant.forLong(1L << Integer.SIZE));
+                    VexRMOp.VPBROADCASTQ.encoding(encoding).emit(masm, YMM, xtmp2Reg, inc);
+                    VexRVMOp.VPOR.encoding(encoding).emit(masm, YMM, xtmp2Reg, xtmp1Reg, xtmp2Reg);
+                    VexShiftOp.VPSLLQ.encoding(encoding).emit(masm, YMM, xtmp1Reg, indexReg, 1);
+                    VexRVMOp.VPOR.encoding(encoding).emit(masm, YMM, xtmp1Reg, xtmp1Reg, xtmp2Reg);
+                    VexRVMOp op = eKind == AMD64Kind.QWORD ? VexRVMOp.VPERMD : VexRVMOp.VPERMPS;
+                    op.encoding(encoding).emit(masm, YMM, asRegister(result), xtmp1Reg, asRegister(source));
+                }
+                default -> throw GraalError.shouldNotReachHereUnexpectedValue(eKind);
+            }
+        }
+
+        private void emitBytePermute(CompilationResultBuilder crb, AMD64MacroAssembler masm, Register indexReg) {
+            AVXSize avxSize = AVXKind.getRegisterSize(result);
+            switch (avxSize) {
+                case XMM -> VexRVMOp.VPSHUFB.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(source), indexReg);
+                case YMM -> {
+                    Register sourceReg = asRegister(source);
+                    Register xtmp1Reg = asRegister(xtmps[0]);
+                    Register xtmp2Reg = asRegister(xtmps[1]);
+                    Register xtmp3Reg = asRegister(xtmps[2]);
+                    GraalError.guarantee(!indexReg.equals(xtmp1Reg) && !indexReg.equals(xtmp2Reg), "cannot alias");
+
+                    // Find the elements that are collected from the first YMM half
+                    VexRVMIOp.VPERM2I128.emit(masm, YMM, xtmp1Reg, sourceReg, sourceReg, 0x00);
+                    VexRVMOp.VPSHUFB.encoding(encoding).emit(masm, YMM, xtmp1Reg, xtmp1Reg, indexReg);
+
+                    // Find the elements that are collected from the second YMM half
+                    VexRVMIOp.VPERM2I128.emit(masm, YMM, xtmp2Reg, sourceReg, sourceReg, 0x11);
+                    VexRVMOp.VPSHUFB.encoding(encoding).emit(masm, YMM, xtmp2Reg, xtmp2Reg, indexReg);
+
+                    // Blend the results; the 5-th bit of each index is the selector (0 - 15 have the
+                    // 5-th bit being 0 while 16 - 31 have the 5-th bit being 1)
+                    // Shift the 5-th bit to the position of the sign bit to use vpblendvb
+                    VexShiftOp.VPSLLD.encoding(encoding).emit(masm, YMM, xtmp3Reg, indexReg, 3);
+                    VexRVMROp.VPBLENDVB.emit(masm, YMM, asRegister(result), xtmp3Reg, xtmp1Reg, xtmp2Reg);
+                }
+                case ZMM -> {
+                    Register sourceReg = asRegister(source);
+                    Register xtmp1Reg = asRegister(xtmps[0]);
+                    Register xtmp2Reg = asRegister(xtmps[1]);
+                    Register xtmp3Reg = asRegister(xtmps[2]);
+                    Register ktmpReg = asRegister(ktmp);
+                    GraalError.guarantee(!indexReg.equals(xtmp1Reg) && !indexReg.equals(xtmp2Reg) && !indexReg.equals(xtmp3Reg), "cannot alias");
+
+                    // Process the even-index elements
+                    // Find the 2-byte location in the source vector and move to the correct 2-byte
+                    // location in the result
+                    VexShiftOp.EVPSRLD.emit(masm, ZMM, xtmp1Reg, indexReg, 1);
+                    VexRVMOp.EVPERMW.emit(masm, ZMM, xtmp1Reg, xtmp1Reg, sourceReg);
+
+                    // Elements whose indices end with 0 are at the correct positions, while the ones
+                    // whose indices end with 1 need to be shifted right by 8
+                    VexShiftOp.EVPSLLD.emit(masm, ZMM, xtmp3Reg, indexReg, Short.SIZE - 1);
+                    VexRMOp.EVPMOVW2M.emit(masm, ZMM, ktmpReg, xtmp3Reg);
+                    VexShiftOp.EVPSRLD.emit(masm, ZMM, xtmp3Reg, xtmp1Reg, Byte.SIZE);
+                    VexRVMOp.EVPBLENDMW.emit(masm, ZMM, xtmp1Reg, xtmp1Reg, xtmp3Reg, ktmpReg);
+
+                    // Process the odd-index elements
+                    // Find the 2-byte location in the source vector and move to the correct 2-byte
+                    // location in the result
+                    VexShiftOp.EVPSRLD.emit(masm, ZMM, xtmp2Reg, indexReg, Byte.SIZE + 1);
+                    VexRVMOp.EVPERMW.emit(masm, ZMM, xtmp2Reg, xtmp2Reg, sourceReg);
+
+                    // Elements whose indices end with 1 are at the correct positions, while the ones
+                    // whose indices end with 0 need to be shifted left by 8
+                    VexShiftOp.EVPSLLD.emit(masm, ZMM, xtmp3Reg, indexReg, Byte.SIZE - 1);
+                    VexRMOp.EVPMOVW2M.emit(masm, ZMM, ktmpReg, xtmp3Reg);
+                    VexShiftOp.EVPSLLD.emit(masm, ZMM, xtmp3Reg, xtmp2Reg, Byte.SIZE);
+                    VexRVMOp.EVPBLENDMW.emit(masm, ZMM, xtmp2Reg, xtmp3Reg, xtmp2Reg, ktmpReg);
+
+                    // Blend the odd- and even-indexed elements
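+                    // The 0x55... constant sets every even mask bit, so the final blend picks
+                    // alternating byte lanes from the two partial results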
+                    AMD64Address mask = (AMD64Address) crb.asLongConstRef(JavaConstant.forLong(0x5555555555555555L));
+                    VexMoveMaskOp.KMOVQ.emit(masm, XMM, ktmpReg, mask);
+                    VexRVMOp.EVPBLENDMB.emit(masm, ZMM, asRegister(result), xtmp2Reg, xtmp1Reg, ktmpReg);
+                }
+                default -> throw GraalError.shouldNotReachHereUnexpectedValue(avxSize);
+            }
+        }
+    }
+
     public static final class IntToVectorOp extends AMD64LIRInstruction {
         public static final LIRInstructionClass<IntToVectorOp> TYPE = LIRInstructionClass.create(IntToVectorOp.class);