MLIR  21.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::arm_sve;
22 
31  OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
32  ScalableMaskedAddIIntrOp>;
34  OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
35  ScalableMaskedAddFIntrOp>;
37  OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
38  ScalableMaskedSubIIntrOp>;
40  OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
41  ScalableMaskedSubFIntrOp>;
43  OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
44  ScalableMaskedMulIIntrOp>;
46  OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
47  ScalableMaskedMulFIntrOp>;
49  OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
50  ScalableMaskedSDivIIntrOp>;
52  OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
53  ScalableMaskedUDivIIntrOp>;
55  OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
56  ScalableMaskedDivFIntrOp>;
57 
58 namespace {
59 
60 /// Unrolls a conversion to/from equivalent vector types, to allow using a
61 /// conversion intrinsic that only supports 1-D vector types.
62 ///
63 /// Example:
64 /// ```
65 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
66 /// ```
67 /// is rewritten into:
68 /// ```
69 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
70 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
71 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
72 /// : (vector<[4]xi1>) -> vector<[16]xi1>
73 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
74 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
75 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
76 /// : (vector<[4]xi1>) -> vector<[16]xi1>
77 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
78 /// ```
79 template <typename Op, typename IntrOp>
80 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
82 
83  LogicalResult
84  matchAndRewrite(Op convertOp, typename Op::Adaptor,
85  ConversionPatternRewriter &rewriter) const override {
86  auto loc = convertOp.getLoc();
87 
88  auto source = convertOp.getSource();
89  VectorType sourceType = source.getType();
90  VectorType resultType = convertOp.getResult().getType();
91 
92  Value result = rewriter.create<arith::ConstantOp>(
93  loc, resultType, rewriter.getZeroAttr(resultType));
94 
95  // We want to iterate over the input vector in steps of the trailing
96  // dimension. So this creates tile shape where all leading dimensions are 1,
97  // and the trailing dimension step is the size of the dimension.
98  SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
99  tileShape.back() = sourceType.getShape().back();
100 
101  // Iterate over all scalable mask/predicate slices of the source vector.
102  for (SmallVector<int64_t> index :
103  StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
104  auto extractOrInsertPosition = ArrayRef(index).drop_back();
105  auto sourceVector = rewriter.create<vector::ExtractOp>(
106  loc, source, extractOrInsertPosition);
107  VectorType convertedType =
108  VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
109  .setDim(0, resultType.getShape().back());
110  auto convertedVector =
111  rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
112  result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
113  extractOrInsertPosition);
114  }
115 
116  rewriter.replaceOp(convertOp, result);
117  return success();
118  }
119 };
120 
121 using ConvertToSvboolOpLowering =
122  SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
123 
124 using ConvertFromSvboolOpLowering =
125  SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
126 
129 
130 /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
131 /// but first input (P1) and result predicates need conversion to/from svbool.
132 struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
134 
135  LogicalResult
136  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
137  ConversionPatternRewriter &rewriter) const override {
138  auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
139  auto loc = pselOp.getLoc();
140  auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
141  adaptor.getP1());
142  auto indexI32 = rewriter.create<arith::IndexCastOp>(
143  loc, rewriter.getI32Type(), pselOp.getIndex());
144  auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
145  pselOp.getP2(), indexI32);
146  rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
147  pselOp, adaptor.getP1().getType(), pselIntr);
148  return success();
149  }
150 };
151 
152 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
153 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
154 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
155 /// for more details. Note that we can't use (the more general) active.lane.mask
156 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
157 /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
158 /// `n` is zero (whereas `create_mask` just returns an all-false mask).
159 struct CreateMaskOpLowering
160  : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
162 
163  LogicalResult
164  matchAndRewrite(vector::CreateMaskOp createMaskOp,
165  vector::CreateMaskOp::Adaptor adaptor,
166  ConversionPatternRewriter &rewriter) const override {
167  auto maskType = createMaskOp.getVectorType();
168  if (maskType.getRank() != 1 || !maskType.isScalable())
169  return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
170 
171  // TODO: Support masks which are multiples of SVE predicates.
172  auto maskBaseSize = maskType.getDimSize(0);
173  if (maskBaseSize < 2 || maskBaseSize > 16 ||
174  !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
175  return rewriter.notifyMatchFailure(createMaskOp,
176  "not SVE predicate-sized");
177 
178  auto loc = createMaskOp.getLoc();
179  auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
180  rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
181  adaptor.getOperands()[0]);
182  return success();
183  }
184 };
185 
186 } // namespace
187 
188 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
190  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
191  // Populate conversion patterns
  // Patterns rewriting ArmSVE ops into their `*IntrOp` counterparts; these are
  // registered with the default pattern benefit (only `converter` is passed).
192 
193  // clang-format off
194  patterns.add<ConvertFromSvboolOpLowering,
195  ConvertToSvboolOpLowering,
197  PselOpLowering,
211  ZipX2OpLowering,
212  ZipX4OpLowering,
213  SdotOpLowering>(converter);
214  // Add vector.create_mask conversion with a high benefit as it produces much
215  // nicer code than the generic lowering.
216  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
217  // clang-format on
218 }
219 
221  LLVMConversionTarget &target) {
  // Configures `target` so that the `*IntrOp`s (ops that map to LLVM
  // intrinsics) are the accepted end state, while the ArmSVE ops they are
  // lowered from must not survive the conversion.
222  // clang-format off
  // Intrinsic ops are legal. `WhileLTIntrOp` has no ArmSVE counterpart in the
  // illegal list because it is produced from `vector.create_mask`.
223  target.addLegalOp<ConvertFromSvboolIntrOp,
224  ConvertToSvboolIntrOp,
225  DupQLaneIntrOp,
226  PselIntrOp,
227  ScalableMaskedAddFIntrOp,
228  ScalableMaskedAddIIntrOp,
229  ScalableMaskedDivFIntrOp,
230  ScalableMaskedMulFIntrOp,
231  ScalableMaskedMulIIntrOp,
232  ScalableMaskedSDivIIntrOp,
233  ScalableMaskedSubFIntrOp,
234  ScalableMaskedSubIIntrOp,
235  ScalableMaskedUDivIIntrOp,
236  SmmlaIntrOp,
237  UdotIntrOp,
238  UmmlaIntrOp,
239  UsmmlaIntrOp,
240  WhileLTIntrOp,
241  ZipX2IntrOp,
242  ZipX4IntrOp,
243  SdotIntrOp>();
  // The high-level ArmSVE ops are illegal, so each must be rewritten by one of
  // the registered conversion patterns into its intrinsic counterpart.
244  target.addIllegalOp<ConvertFromSvboolOp,
245  ConvertToSvboolOp,
246  DupQLaneOp,
247  PselOp,
248  ScalableMaskedAddFOp,
249  ScalableMaskedAddIOp,
250  ScalableMaskedDivFOp,
251  ScalableMaskedMulFOp,
252  ScalableMaskedMulIOp,
253  ScalableMaskedSDivIOp,
254  ScalableMaskedSubFOp,
255  ScalableMaskedSubIOp,
256  ScalableMaskedUDivIOp,
257  SmmlaOp,
258  UdotOp,
259  UmmlaOp,
260  UsmmlaOp,
261  ZipX2Op,
262  ZipX4Op,
263  SdotOp>();
264  // clang-format on
265 }
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< UsmmlaOp, UsmmlaIntrOp > UsmmlaOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< DupQLaneOp, DupQLaneIntrOp > DupQLaneLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
IntegerType getI1Type()
Definition: Builders.cpp:53
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:161
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:304
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.