MLIR  22.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::arm_sve;
20 
29  OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
30  ScalableMaskedAddIIntrOp>;
32  OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
33  ScalableMaskedAddFIntrOp>;
35  OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
36  ScalableMaskedSubIIntrOp>;
38  OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
39  ScalableMaskedSubFIntrOp>;
41  OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
42  ScalableMaskedMulIIntrOp>;
44  OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
45  ScalableMaskedMulFIntrOp>;
47  OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
48  ScalableMaskedSDivIIntrOp>;
50  OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
51  ScalableMaskedUDivIIntrOp>;
53  OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
54  ScalableMaskedDivFIntrOp>;
55 
56 namespace {
57 
58 /// Unrolls a conversion to/from equivalent vector types, to allow using a
59 /// conversion intrinsic that only supports 1-D vector types.
60 ///
61 /// Example:
62 /// ```
63 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
64 /// ```
65 /// is rewritten into:
66 /// ```
67 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
68 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
69 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
70 /// : (vector<[4]xi1>) -> vector<[16]xi1>
71 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
72 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
73 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
74 /// : (vector<[4]xi1>) -> vector<[16]xi1>
75 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
76 /// ```
77 template <typename Op, typename IntrOp>
78 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
80 
81  LogicalResult
82  matchAndRewrite(Op convertOp, typename Op::Adaptor,
83  ConversionPatternRewriter &rewriter) const override {
84  auto loc = convertOp.getLoc();
85 
86  auto source = convertOp.getSource();
87  VectorType sourceType = source.getType();
88  VectorType resultType = convertOp.getResult().getType();
89 
90  Value result = arith::ConstantOp::create(rewriter, loc, resultType,
91  rewriter.getZeroAttr(resultType));
92 
93  // We want to iterate over the input vector in steps of the trailing
94  // dimension. So this creates tile shape where all leading dimensions are 1,
95  // and the trailing dimension step is the size of the dimension.
96  SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
97  tileShape.back() = sourceType.getShape().back();
98 
99  // Iterate over all scalable mask/predicate slices of the source vector.
100  for (SmallVector<int64_t> index :
101  StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
102  auto extractOrInsertPosition = ArrayRef(index).drop_back();
103  auto sourceVector = vector::ExtractOp::create(rewriter, loc, source,
104  extractOrInsertPosition);
105  VectorType convertedType =
106  VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
107  .setDim(0, resultType.getShape().back());
108  auto convertedVector =
109  IntrOp::create(rewriter, loc, TypeRange{convertedType}, sourceVector);
110  result = vector::InsertOp::create(rewriter, loc, convertedVector, result,
111  extractOrInsertPosition);
112  }
113 
114  rewriter.replaceOp(convertOp, result);
115  return success();
116  }
117 };
118 
119 using ConvertToSvboolOpLowering =
120  SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
121 
122 using ConvertFromSvboolOpLowering =
123  SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
124 
127 
128 /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
129 /// but first input (P1) and result predicates need conversion to/from svbool.
130 struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
132 
133  LogicalResult
134  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
135  ConversionPatternRewriter &rewriter) const override {
136  auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
137  auto loc = pselOp.getLoc();
138  auto svboolP1 = ConvertToSvboolIntrOp::create(rewriter, loc, svboolType,
139  adaptor.getP1());
140  auto indexI32 = arith::IndexCastOp::create(
141  rewriter, loc, rewriter.getI32Type(), pselOp.getIndex());
142  auto pselIntr = PselIntrOp::create(rewriter, loc, svboolType, svboolP1,
143  pselOp.getP2(), indexI32);
144  rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
145  pselOp, adaptor.getP1().getType(), pselIntr);
146  return success();
147  }
148 };
149 
150 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
151 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
152 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
153 /// for more details. Note that we can't use (the more general) active.lane.mask
/// as its semantics don't neatly map onto `vector.create_mask`, as it does an
/// unsigned comparison (whereas `create_mask` is signed), and is UB/poison if
156 /// `n` is zero (whereas `create_mask` just returns an all-false mask).
157 struct CreateMaskOpLowering
158  : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
160 
161  LogicalResult
162  matchAndRewrite(vector::CreateMaskOp createMaskOp,
163  vector::CreateMaskOp::Adaptor adaptor,
164  ConversionPatternRewriter &rewriter) const override {
165  auto maskType = createMaskOp.getVectorType();
166  if (maskType.getRank() != 1 || !maskType.isScalable())
167  return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
168 
169  // TODO: Support masks which are multiples of SVE predicates.
170  auto maskBaseSize = maskType.getDimSize(0);
171  if (maskBaseSize < 2 || maskBaseSize > 16 ||
172  !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
173  return rewriter.notifyMatchFailure(createMaskOp,
174  "not SVE predicate-sized");
175 
176  auto loc = createMaskOp.getLoc();
177  auto zero = LLVM::ZeroOp::create(rewriter, loc, rewriter.getI64Type());
178  rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
179  adaptor.getOperands()[0]);
180  return success();
181  }
182 };
183 
184 } // namespace
185 
186 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
188  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
189  // Populate conversion patterns
190 
191  // clang-format off
192  patterns.add<ConvertFromSvboolOpLowering,
193  ConvertToSvboolOpLowering,
195  PselOpLowering,
209  ZipX2OpLowering,
210  ZipX4OpLowering,
211  SdotOpLowering>(converter);
212  // Add vector.create_mask conversion with a high benefit as it produces much
213  // nicer code than the generic lowering.
214  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
215  // clang-format on
216 }
217 
219  LLVMConversionTarget &target) {
220  // clang-format off
221  target.addLegalOp<BfmmlaOp,
222  ConvertFromSvboolIntrOp,
223  ConvertToSvboolIntrOp,
224  DupQLaneIntrOp,
225  PselIntrOp,
226  ScalableMaskedAddFIntrOp,
227  ScalableMaskedAddIIntrOp,
228  ScalableMaskedDivFIntrOp,
229  ScalableMaskedMulFIntrOp,
230  ScalableMaskedMulIIntrOp,
231  ScalableMaskedSDivIIntrOp,
232  ScalableMaskedSubFIntrOp,
233  ScalableMaskedSubIIntrOp,
234  ScalableMaskedUDivIIntrOp,
235  SmmlaIntrOp,
236  UdotIntrOp,
237  UmmlaIntrOp,
238  UsmmlaIntrOp,
239  WhileLTIntrOp,
240  ZipX2IntrOp,
241  ZipX4IntrOp,
242  SdotIntrOp>();
243  target.addIllegalOp<ConvertFromSvboolOp,
244  ConvertToSvboolOp,
245  DupQLaneOp,
246  PselOp,
247  ScalableMaskedAddFOp,
248  ScalableMaskedAddIOp,
249  ScalableMaskedDivFOp,
250  ScalableMaskedMulFOp,
251  ScalableMaskedMulIOp,
252  ScalableMaskedSDivIOp,
253  ScalableMaskedSubFOp,
254  ScalableMaskedSubIOp,
255  ScalableMaskedUDivIOp,
256  SmmlaOp,
257  UdotOp,
258  UmmlaOp,
259  UsmmlaOp,
260  ZipX2Op,
261  ZipX4Op,
262  SdotOp>();
263  // clang-format on
264 }
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< UsmmlaOp, UsmmlaIntrOp > UsmmlaOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< DupQLaneOp, DupQLaneIntrOp > DupQLaneLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
IntegerType getI1Type()
Definition: Builders.cpp:52
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:292
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:129
This provides public APIs that all operations should have.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:286
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:320
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.