MLIR  20.0.0git
Transforms.h
Go to the documentation of this file.
1 //=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
10 #define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
11 
12 #include "mlir/IR/Value.h"
13 
14 namespace mlir {
15 
16 class ImplicitLocOpBuilder;
17 class LLVMConversionTarget;
18 class LLVMTypeConverter;
19 class RewritePatternSet;
20 
21 namespace x86vector {
22 
/// Helper class to factor out the creation and extraction of masks from nibs.
///
/// The packing helpers (blend / shuffle / permute) build an 8-bit immediate
/// from compile-time fields; each extract* helper is the inverse of its
/// packing counterpart, so extracting a packed mask yields the original
/// field values.
struct MaskHelper {
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Each template argument must be 0 or 1 (enforced at compile time).
  /// Meant to be used with instructions such as mm256BlendPs.
  template <uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4,
            uint8_t b5, uint8_t b6, uint8_t b7>
  static constexpr uint8_t blend() {
    static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow");
    static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow");
    return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
                                (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
  }
  /// Inverse of blend(): b0 receives the lowest bit, b7 the highest bit.
  /// Every output is 0 or 1, i.e. the same domain blend() accepts.
  /// Meant to be used with instructions such as mm256BlendPs.
  static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2,
                           uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6,
                           uint8_t &b7) {
    // Shift each bit down to position 0 instead of masking it in place:
    // `mask & (1 << i)` would produce 0 or (1 << i), which neither matches
    // the "one bit per output" contract nor round-trips through blend().
    b7 = (mask >> 7) & 1;
    b6 = (mask >> 6) & 1;
    b5 = (mask >> 5) & 1;
    b4 = (mask >> 4) & 1;
    b3 = (mask >> 3) & 1;
    b2 = (mask >> 2) & 1;
    b1 = (mask >> 1) & 1;
    b0 = mask & 1;
  }
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  /// Each template argument must fit in 2 bits (enforced at compile time).
  /// Meant to be used with instructions such as mm256ShufflePs.
  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
  static constexpr uint8_t shuffle() {
    static_assert(b01 <= 0x03, "overflow");
    static_assert(b23 <= 0x03, "overflow");
    static_assert(b45 <= 0x03, "overflow");
    static_assert(b67 <= 0x03, "overflow");
    return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
  }
  /// Inverse of shuffle(): b01 receives the lower 2 bits, b67 the higher 2
  /// bits; every output is in [0, 3].
  static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23,
                             uint8_t &b45, uint8_t &b67) {
    b67 = (mask & (0x03 << 6)) >> 6;
    b45 = (mask & (0x03 << 4)) >> 4;
    b23 = (mask & (0x03 << 2)) >> 2;
    b01 = mask & 0x03;
  }
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  /// Each template argument must fit in 4 bits (enforced at compile time).
  /// Meant to be used with instructions such as mm256Permute2f128Ps.
  template <unsigned b47, unsigned b03>
  static constexpr uint8_t permute() {
    static_assert(b03 <= 0x0f, "overflow");
    static_assert(b47 <= 0x0f, "overflow");
    // `|` rather than `+` to match the other packers; the static_asserts
    // guarantee the two nibbles never overlap, so the result is identical.
    return static_cast<uint8_t>((b47 << 4) | b03);
  }
  /// Inverse of permute(): b03 receives the lower 4 bits, b47 the higher 4
  /// bits; every output is in [0, 15].
  static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) {
    b47 = (mask & (0x0f << 4)) >> 4;
    b03 = mask & 0x0f;
  }
};
81 
82 //===----------------------------------------------------------------------===//
83 /// Helpers extracted from:
84 /// - clang/lib/Headers/avxintrin.h
85 /// - clang/test/CodeGen/X86/avx-builtins.c
86 /// - clang/test/CodeGen/X86/avx2-builtins.c
87 /// - clang/test/CodeGen/X86/avx-shuffle-builtins.c
88 /// as well as the Intel Intrinsics Guide
89 /// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
90 /// make it easier to just implement known good lowerings.
91 /// All intrinsics correspond 1-1 to the Intel definition.
92 //===----------------------------------------------------------------------===//
93 
94 namespace avx2 {
95 
96 namespace inline_asm {
97 //===----------------------------------------------------------------------===//
98 /// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
99 //===----------------------------------------------------------------------===//
100 /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
102  uint8_t mask);
103 
104 } // namespace inline_asm
105 
106 namespace intrin {
107 //===----------------------------------------------------------------------===//
108 /// Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
109 //===----------------------------------------------------------------------===//
110 /// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
112 
113 /// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
115 
116 /// a a b b a a b b
117 /// Take an 8-bit mask, 2 bits for each position of a[0, 4) **and** b[0, 4):
118 /// 0:127 | 128:255
119 /// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
120 Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
121 
122 // imm[0:1] out of imm[0:3] is:
123 // 0 1 2 3
124 // a[0:127] or a[128:255] or b[0:127] or b[128:255] |
125 // a[0:127] or a[128:255] or b[0:127] or b[128:255]
126 // 0 1 2 3
127 // imm[0:1] out of imm[4:7].
129  uint8_t mask);
130 
131 /// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
132 Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
133 } // namespace intrin
134 
135 //===----------------------------------------------------------------------===//
136 /// Generic lowerings may either use intrin or inline_asm depending on needs.
137 //===----------------------------------------------------------------------===//
138 /// 4x8xf32-specific AVX2 transpose lowering.
140 
141 /// 8x8xf32-specific AVX2 transpose lowering.
143 
144 /// Structure to control the behavior of specialized AVX2 transpose lowering.
146  bool lower4x8xf32_ = false;
148  lower4x8xf32_ = lower;
149  return *this;
150  }
151  bool lower8x8xf32_ = false;
153  lower8x8xf32_ = lower;
154  return *this;
155  }
156 };
157 
158 /// Options for controlling specialized AVX2 lowerings.
160  /// Configure specialized vector lowerings.
164  return *this;
165  }
166 };
167 
168 /// Insert specialized transpose lowering patterns.
170  RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
171  int benefit = 10);
172 
173 } // namespace avx2
174 } // namespace x86vector
175 
176 /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
177 /// intrinsics.
179  const LLVMTypeConverter &converter, RewritePatternSet &patterns);
180 
181 /// Configure the target to support lowering X86Vector ops to ops that map to
182 /// LLVM intrinsics.
183 void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);
184 
185 } // namespace mlir
186 
187 #endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
static llvm::ManagedStatic< PassManagerOptions > options
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
If bit i of mask is zero, take f32@i from v1 else take it from v2.
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)
a a b b a a b b Take an 8-bit mask, 2 bits for each position of a[0, 4) and b[0, 4): 0:127 | 128:255 b...
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2)
Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
8x8xf32-specific AVX2 transpose lowering.
void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)
Insert specialized transpose lowering patterns.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)
Generic lowerings may either use intrin or inline_asm depending on needs.
Include the generated interface declarations.
void populateX86VectorLegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.
Helper class to factor out the creation and extraction of masks from nibs.
Definition: Transforms.h:24
static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2, uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6, uint8_t &b7)
b0 captures the lowest bit, b7 captures the highest bit.
Definition: Transforms.h:37
static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47)
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
Definition: Transforms.h:76
static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23, uint8_t &b45, uint8_t &b67)
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
Definition: Transforms.h:60
static uint8_t shuffle()
b01 captures the lower 2 bits, b67 captures the higher 2 bits.
Definition: Transforms.h:52
static uint8_t blend()
b0 captures the lowest bit, b7 captures the highest bit.
Definition: Transforms.h:29
static uint8_t permute()
b03 captures the lower 4 bits, b47 captures the higher 4 bits.
Definition: Transforms.h:70
Options for controlling specialized AVX2 lowerings.
Definition: Transforms.h:159
LoweringOptions & setTransposeOptions(TransposeLoweringOptions options)
Definition: Transforms.h:162
TransposeLoweringOptions transposeOptions
Configure specialized vector lowerings.
Definition: Transforms.h:161
Structure to control the behavior of specialized AVX2 transpose lowering.
Definition: Transforms.h:145
TransposeLoweringOptions & lower8x8xf32(bool lower=true)
Definition: Transforms.h:152
TransposeLoweringOptions & lower4x8xf32(bool lower=true)
Definition: Transforms.h:147