MLIR 23.0.0git
Transforms.h
Go to the documentation of this file.
1//=- Transforms.h - X86 Dialect Transformation Entrypoints --------*- C++ -*-=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_X86_TRANSFORMS_H
10#define MLIR_DIALECT_X86_TRANSFORMS_H
11
12#include "mlir/IR/Value.h"
13
14namespace mlir {
15
20
21namespace x86 {
22
/// Helper class to factor out the creation of 8-bit immediate masks from
/// their individual bit fields, and the extraction of those fields back out
/// of a mask.
struct MaskHelper {
  /// Builds a blend mask from eight single-bit fields.
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Every field must be 0 or 1; anything else fails the static_assert.
  /// Meant to be used with instructions such as mm256BlendPs.
  template <uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4,
            uint8_t b5, uint8_t b6, uint8_t b7>
  static uint8_t blend() {
    static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow");
    static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow");
    return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
                                (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
  }

  /// Splits a blend mask into its eight single-bit fields.
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Each output is normalized to 0 or 1 so that the values can be fed back
  /// into blend(); a caller that only tests an output for zero / non-zero
  /// observes the same truth value as with an un-normalized `mask & (1 << k)`.
  /// Meant to be used with instructions such as mm256BlendPs.
  static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2,
                           uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6,
                           uint8_t &b7) {
    b7 = (mask >> 7) & 1;
    b6 = (mask >> 6) & 1;
    b5 = (mask >> 5) & 1;
    b4 = (mask >> 4) & 1;
    b3 = (mask >> 3) & 1;
    b2 = (mask >> 2) & 1;
    b1 = (mask >> 1) & 1;
    b0 = mask & 1;
  }

  /// Builds a shuffle mask from four 2-bit fields.
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  /// Note the template argument order: highest field first.
  /// Meant to be used with instructions such as mm256ShufflePs.
  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
  static uint8_t shuffle() {
    static_assert(b01 <= 0x03, "overflow");
    static_assert(b23 <= 0x03, "overflow");
    static_assert(b45 <= 0x03, "overflow");
    static_assert(b67 <= 0x03, "overflow");
    return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
  }

  /// Splits a shuffle mask into its four 2-bit fields.
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  /// Note the output argument order: lowest field first.
  static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23,
                             uint8_t &b45, uint8_t &b67) {
    b67 = (mask >> 6) & 0x03;
    b45 = (mask >> 4) & 0x03;
    b23 = (mask >> 2) & 0x03;
    b01 = mask & 0x03;
  }

  /// Builds a permute mask from two 4-bit fields.
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  /// Meant to be used with instructions such as mm256Permute2f128Ps.
  template <unsigned b47, unsigned b03>
  static uint8_t permute() {
    static_assert(b03 <= 0x0f, "overflow");
    static_assert(b47 <= 0x0f, "overflow");
    // The two fields occupy disjoint bits, so OR-ing them is equivalent to
    // adding them and matches the form used by blend() and shuffle().
    return static_cast<uint8_t>((b47 << 4) | b03);
  }

  /// Splits a permute mask into its two 4-bit fields.
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) {
    b47 = (mask >> 4) & 0x0f;
    b03 = mask & 0x0f;
  }
};
81
82//===----------------------------------------------------------------------===//
83
84// A set of patterns for specialized lowering of vector contraction
85// operation to vector fused multiply and add (FMA) operation.
87
88// A set of patterns for lowering 32-bit packed vector contraction operations
89// to their corresponding packed-type dot-product operations, ultimately
90// targeting the relevant x86 LLVM intrinsics (e.g., BF16 and Int8).
92 RewritePatternSet &patterns);
93
94// A set of patterns for lowering 32-bit packed BF16 vector contraction
95// operations to vector fused multiply-add (FMA) operations, following
96// the emulation-based approach using BF16 packed operations.
98
99// Performs forward scheduling of vector producer ops to minimize their live
100// range by placing them at their earliest legal use site.
102
103// Shuffles FMAs with x86 operations as operands such that FMAs are
104// grouped with respect to odd/even packed index.
106
107// A set of patterns for lowering 32-bit packed vector contraction operations
108// to their corresponding packed-type tiled dot-product operations, using
109// AMX ultimately targeting the relevant x86 LLVM intrinsics (e.g., BF16 and
110// Int8).
112
113//===----------------------------------------------------------------------===//
114/// Helpers extracted from:
115/// - clang/lib/Headers/avxintrin.h
116/// - clang/test/CodeGen/X86/avx-builtins.c
117/// - clang/test/CodeGen/X86/avx2-builtins.c
118/// - clang/test/CodeGen/X86/avx-shuffle-builtins.c
119/// as well as the Intel Intrinsics Guide
120/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
121/// make it easier to just implement known good lowerings.
122/// All intrinsics correspond 1-1 to the Intel definition.
123//===----------------------------------------------------------------------===//
124
125namespace avx2 {
126
127namespace inline_asm {
128//===----------------------------------------------------------------------===//
129/// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
130//===----------------------------------------------------------------------===//
131/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
133 uint8_t mask);
134
135} // namespace inline_asm
136
137namespace intrin {
138//===----------------------------------------------------------------------===//
139/// Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
140//===----------------------------------------------------------------------===//
141/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
143
144/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
146
147/// a a b b a a b b
148/// Take an 8 bit mask, 2 bit for each position of a[0, 4) **and** b[0, 4):
149/// 0:127 | 128:255
150/// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
151Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
152
153// imm[0:1] out of imm[0:3] is:
154// 0 1 2 3
155// a[0:127] or a[128:255] or b[0:127] or b[128:255] |
156// a[0:127] or a[128:255] or b[0:127] or b[128:255]
157// 0 1 2 3
158// imm[0:1] out of imm[4:7].
160 uint8_t mask);
161
162/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
163Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
164} // namespace intrin
165
166//===----------------------------------------------------------------------===//
167/// Generic lowerings may either use intrin or inline_asm depending on needs.
168//===----------------------------------------------------------------------===//
169/// 4x8xf32-specific AVX2 transpose lowering.
171
172/// 8x8xf32-specific AVX2 transpose lowering.
174
175/// Structure to control the behavior of specialized AVX2 transpose lowering.
177 bool lower4x8xf32_ = false;
179 lower4x8xf32_ = lower;
180 return *this;
181 }
182 bool lower8x8xf32_ = false;
184 lower8x8xf32_ = lower;
185 return *this;
186 }
187};
188
189/// Options for controlling specialized AVX2 lowerings.
191 /// Configure specialized vector lowerings.
197};
198
199/// Insert specialized transpose lowering patterns.
201 RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
202 int benefit = 10);
203
204} // namespace avx2
205} // namespace x86
206
207/// Collect a set of patterns to lower X86 ops to ops that map to LLVM
208/// intrinsics.
209void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
210 RewritePatternSet &patterns);
211
212/// Configure the target to support lowering X86 ops to ops that map to
213/// LLVM intrinsics.
214void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
215
216/// Register LLVM conversion interface for X86 dialect.
217void registerConvertX86ToLLVMInterface(DialectRegistry &registry);
218
219} // namespace mlir
220
221#endif // MLIR_DIALECT_X86_TRANSFORMS_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
If bit i of mask is zero, take f32@i from v1 else take it from v2.
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
a a b b a a b b Take an 8 bit mask, 2 bit for each position of a[0, 4) and b[0, 4): 0:127 | 128:255 b...
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Helpers extracted from:
Definition Transforms.h:125
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
Generic lowerings may either use intrin or inline_asm depending on needs.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
8x8xf32-specific AVX2 transpose lowering.
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns)
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns)
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns)
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
void configureX86LegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86 ops to ops that map to LLVM intrinsics.
void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86 ops to ops that map to LLVM intrinsics.
void registerConvertX86ToLLVMInterface(DialectRegistry &registry)
Register LLVM conversion interface for X86 dialect.
Helper class to factor out the creation and extraction of masks from nibs.
Definition Transforms.h:24
static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47)
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
Definition Transforms.h:76
static uint8_t blend()
b0 captures the lowest bit, b7 captures the highest bit.
Definition Transforms.h:29
static uint8_t shuffle()
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
Definition Transforms.h:52
static uint8_t permute()
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
Definition Transforms.h:70
static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2, uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6, uint8_t &b7)
b0 captures the lowest bit, b7 captures the highest bit.
Definition Transforms.h:37
static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23, uint8_t &b45, uint8_t &b67)
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
Definition Transforms.h:60
Options for controlling specialized AVX2 lowerings.
Definition Transforms.h:190
TransposeLoweringOptions transposeOptions
Configure specialized vector lowerings.
Definition Transforms.h:192
LoweringOptions & setTransposeOptions(TransposeLoweringOptions options)
Definition Transforms.h:193
Structure to control the behavior of specialized AVX2 transpose lowering.
Definition Transforms.h:176
TransposeLoweringOptions & lower8x8xf32(bool lower=true)
Definition Transforms.h:183
TransposeLoweringOptions & lower4x8xf32(bool lower=true)
Definition Transforms.h:178