MLIR  21.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::arm_sve;
22 
31  OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
32  ScalableMaskedAddIIntrOp>;
34  OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
35  ScalableMaskedAddFIntrOp>;
37  OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
38  ScalableMaskedSubIIntrOp>;
40  OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
41  ScalableMaskedSubFIntrOp>;
43  OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
44  ScalableMaskedMulIIntrOp>;
46  OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
47  ScalableMaskedMulFIntrOp>;
49  OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
50  ScalableMaskedSDivIIntrOp>;
52  OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
53  ScalableMaskedUDivIIntrOp>;
55  OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
56  ScalableMaskedDivFIntrOp>;
57 
58 namespace {
59 
60 /// Unrolls a conversion to/from equivalent vector types, to allow using a
61 /// conversion intrinsic that only supports 1-D vector types.
62 ///
63 /// Example:
64 /// ```
65 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
66 /// ```
67 /// is rewritten into:
68 /// ```
69 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
70 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
71 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
72 /// : (vector<[4]xi1>) -> vector<[16]xi1>
73 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
74 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
75 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
76 /// : (vector<[4]xi1>) -> vector<[16]xi1>
77 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
78 /// ```
79 template <typename Op, typename IntrOp>
80 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
82 
83  LogicalResult
84  matchAndRewrite(Op convertOp, typename Op::Adaptor,
85  ConversionPatternRewriter &rewriter) const override {
86  auto loc = convertOp.getLoc();
87 
88  auto source = convertOp.getSource();
89  VectorType sourceType = source.getType();
90  VectorType resultType = convertOp.getResult().getType();
91 
92  Value result = rewriter.create<arith::ConstantOp>(
93  loc, resultType, rewriter.getZeroAttr(resultType));
94 
95  // We want to iterate over the input vector in steps of the trailing
96  // dimension. So this creates tile shape where all leading dimensions are 1,
97  // and the trailing dimension step is the size of the dimension.
98  SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
99  tileShape.back() = sourceType.getShape().back();
100 
101  // Iterate over all scalable mask/predicate slices of the source vector.
102  for (SmallVector<int64_t> index :
103  StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
104  auto extractOrInsertPosition = ArrayRef(index).drop_back();
105  auto sourceVector = rewriter.create<vector::ExtractOp>(
106  loc, source, extractOrInsertPosition);
107  VectorType convertedType =
108  VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
109  .setDim(0, resultType.getShape().back());
110  auto convertedVector =
111  rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
112  result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
113  extractOrInsertPosition);
114  }
115 
116  rewriter.replaceOp(convertOp, result);
117  return success();
118  }
119 };
120 
121 using ConvertToSvboolOpLowering =
122  SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
123 
124 using ConvertFromSvboolOpLowering =
125  SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
126 
129 
130 /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
131 /// but first input (P1) and result predicates need conversion to/from svbool.
132 struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
134 
135  LogicalResult
136  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
137  ConversionPatternRewriter &rewriter) const override {
138  auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
139  auto loc = pselOp.getLoc();
140  auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
141  adaptor.getP1());
142  auto indexI32 = rewriter.create<arith::IndexCastOp>(
143  loc, rewriter.getI32Type(), pselOp.getIndex());
144  auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
145  pselOp.getP2(), indexI32);
146  rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
147  pselOp, adaptor.getP1().getType(), pselIntr);
148  return success();
149  }
150 };
151 
152 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
153 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
154 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
155 /// for more details. Note that we can't use (the more general) active.lane.mask
156 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
157 /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
158 /// `n` is zero (whereas `create_mask` just returns an all-false mask).
159 struct CreateMaskOpLowering
160  : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
162 
163  LogicalResult
164  matchAndRewrite(vector::CreateMaskOp createMaskOp,
165  vector::CreateMaskOp::Adaptor adaptor,
166  ConversionPatternRewriter &rewriter) const override {
167  auto maskType = createMaskOp.getVectorType();
168  if (maskType.getRank() != 1 || !maskType.isScalable())
169  return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
170 
171  // TODO: Support masks which are multiples of SVE predicates.
172  auto maskBaseSize = maskType.getDimSize(0);
173  if (maskBaseSize < 2 || maskBaseSize > 16 ||
174  !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
175  return rewriter.notifyMatchFailure(createMaskOp,
176  "not SVE predicate-sized");
177 
178  auto loc = createMaskOp.getLoc();
179  auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
180  rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
181  adaptor.getOperands()[0]);
182  return success();
183  }
184 };
185 
186 } // namespace
187 
188 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
190  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
191  // Populate conversion patterns
192 
193  // clang-format off
194  patterns.add<ConvertFromSvboolOpLowering,
195  ConvertToSvboolOpLowering,
197  PselOpLowering,
211  ZipX2OpLowering,
212  ZipX4OpLowering,
213  SdotOpLowering>(converter);
214  // Add vector.create_mask conversion with a high benefit as it produces much
215  // nicer code than the generic lowering.
216  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
217  // clang-format on
218 }
219 
221  LLVMConversionTarget &target) {
222  // clang-format off
223  target.addLegalOp<BfmmlaOp,
224  ConvertFromSvboolIntrOp,
225  ConvertToSvboolIntrOp,
226  DupQLaneIntrOp,
227  PselIntrOp,
228  ScalableMaskedAddFIntrOp,
229  ScalableMaskedAddIIntrOp,
230  ScalableMaskedDivFIntrOp,
231  ScalableMaskedMulFIntrOp,
232  ScalableMaskedMulIIntrOp,
233  ScalableMaskedSDivIIntrOp,
234  ScalableMaskedSubFIntrOp,
235  ScalableMaskedSubIIntrOp,
236  ScalableMaskedUDivIIntrOp,
237  SmmlaIntrOp,
238  UdotIntrOp,
239  UmmlaIntrOp,
240  UsmmlaIntrOp,
241  WhileLTIntrOp,
242  ZipX2IntrOp,
243  ZipX4IntrOp,
244  SdotIntrOp>();
245  target.addIllegalOp<ConvertFromSvboolOp,
246  ConvertToSvboolOp,
247  DupQLaneOp,
248  PselOp,
249  ScalableMaskedAddFOp,
250  ScalableMaskedAddIOp,
251  ScalableMaskedDivFOp,
252  ScalableMaskedMulFOp,
253  ScalableMaskedMulIOp,
254  ScalableMaskedSDivIOp,
255  ScalableMaskedSubFOp,
256  ScalableMaskedSubIOp,
257  ScalableMaskedUDivIOp,
258  SmmlaOp,
259  UdotOp,
260  UmmlaOp,
261  UsmmlaOp,
262  ZipX2Op,
263  ZipX4Op,
264  SdotOp>();
265  // clang-format on
266 }
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< UsmmlaOp, UsmmlaIntrOp > UsmmlaOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< DupQLaneOp, DupQLaneIntrOp > DupQLaneLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
IntegerType getI1Type()
Definition: Builders.cpp:52
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:292
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:286
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:320
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.