MLIR  21.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::arm_sve;
22 
28  OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
29  ScalableMaskedAddIIntrOp>;
31  OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
32  ScalableMaskedAddFIntrOp>;
34  OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
35  ScalableMaskedSubIIntrOp>;
37  OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
38  ScalableMaskedSubFIntrOp>;
40  OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
41  ScalableMaskedMulIIntrOp>;
43  OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
44  ScalableMaskedMulFIntrOp>;
46  OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
47  ScalableMaskedSDivIIntrOp>;
49  OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
50  ScalableMaskedUDivIIntrOp>;
52  OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
53  ScalableMaskedDivFIntrOp>;
54 
55 namespace {
56 
57 /// Unrolls a conversion to/from equivalent vector types, to allow using a
58 /// conversion intrinsic that only supports 1-D vector types.
59 ///
60 /// Example:
61 /// ```
62 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
63 /// ```
64 /// is rewritten into:
65 /// ```
66 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
67 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
68 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
69 /// : (vector<[4]xi1>) -> vector<[16]xi1>
70 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
71 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
72 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
73 /// : (vector<[4]xi1>) -> vector<[16]xi1>
74 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
75 /// ```
// Template parameters: `Op` is the op being matched and replaced; `IntrOp` is
// the op created for each 1-D slice of the source.
76 template <typename Op, typename IntrOp>
77 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
79 
  // Replaces `convertOp` with one extract / `IntrOp` / insert sequence per
  // trailing-dimension slice of its source, accumulated into a single value.
  // The adaptor parameter is unnamed and unused: the source is read from
  // `convertOp` itself. This rewrite has no failure path and always returns
  // success().
80  LogicalResult
81  matchAndRewrite(Op convertOp, typename Op::Adaptor,
82  ConversionPatternRewriter &rewriter) const override {
83  auto loc = convertOp.getLoc();
84 
85  auto source = convertOp.getSource();
86  VectorType sourceType = source.getType();
87  VectorType resultType = convertOp.getResult().getType();
88 
  // Accumulator for the final value: starts as a zero-attribute constant of
  // the result type; every converted slice is inserted into it below.
89  Value result = rewriter.create<arith::ConstantOp>(
90  loc, resultType, rewriter.getZeroAttr(resultType));
91 
92  // We want to iterate over the input vector in steps of the trailing
93  // dimension. So this creates a tile shape where all leading dimensions are
94  // 1, and the trailing dimension step is the size of the dimension.
95  SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
96  tileShape.back() = sourceType.getShape().back();
97 
98  // Iterate over all scalable mask/predicate slices of the source vector.
99  for (SmallVector<int64_t> index :
100  StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
  // Drop the trailing index: the same leading-dimension position is used
  // both to extract the source slice and to insert the converted slice.
101  auto extractOrInsertPosition = ArrayRef(index).drop_back();
102  auto sourceVector = rewriter.create<vector::ExtractOp>(
103  loc, source, extractOrInsertPosition);
  // The converted slice has the extracted slice's type, except that
  // dimension 0 is set to the size of the result's trailing dimension.
104  VectorType convertedType =
105  VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
106  .setDim(0, resultType.getShape().back());
107  auto convertedVector =
108  rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
109  result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
110  extractOrInsertPosition);
111  }
112 
113  rewriter.replaceOp(convertOp, result);
114  return success();
115  }
116 };
117 
// Instantiation of the unrolling lowering that replaces `ConvertToSvboolOp`
// with per-slice `ConvertToSvboolIntrOp` ops.
118 using ConvertToSvboolOpLowering =
119  SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
120 
// Instantiation of the unrolling lowering that replaces `ConvertFromSvboolOp`
// with per-slice `ConvertFromSvboolIntrOp` ops.
121 using ConvertFromSvboolOpLowering =
122  SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
123 
126 
127 /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
128 /// but first input (P1) and result predicates need conversion to/from svbool.
129 struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
131 
  // Replaces `pselOp` with: convert P1 to `svboolType` -> `PselIntrOp` ->
  // convert the result back to P1's type. Has no failure path; always returns
  // success().
132  LogicalResult
133  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
134  ConversionPatternRewriter &rewriter) const override {
  // Type used for the intrinsic's first operand and its result: an i1 vector
  // with a single dimension of size 16 (presumably vector<[16]xi1>, with the
  // trailing `true` marking the dimension scalable - TODO confirm).
135  auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
136  auto loc = pselOp.getLoc();
137  auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
138  adaptor.getP1());
  // The index operand is cast to i32 before being handed to the intrinsic.
  // NOTE(review): the index and P2 are read from the original op, while P1 is
  // read from the adaptor - confirm this mix is intentional.
139  auto indexI32 = rewriter.create<arith::IndexCastOp>(
140  loc, rewriter.getI32Type(), pselOp.getIndex());
141  auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
142  pselOp.getP2(), indexI32);
  // Convert the intrinsic's result back to the type of P1, which is used as
  // the type of the replacement value.
143  rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
144  pselOp, adaptor.getP1().getType(), pselIntr);
145  return success();
146  }
147 };
148 
149 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
150 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
151 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
152 /// for more details. Note that we can't use (the more general) active.lane.mask
153 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
154 /// unsigned comparison (whereas `create_mask` is signed), and is UB/poison if
155 /// `n` is zero (whereas `create_mask` just returns an all-false mask).
156 struct CreateMaskOpLowering
157  : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
159 
  // Replaces a matching `vector.create_mask` with a `WhileLTIntrOp`. Reports a
  // match failure (leaving the op untouched) unless the mask type is 1-D,
  // scalable, and has a base size of 2, 4, 8 or 16.
160  LogicalResult
161  matchAndRewrite(vector::CreateMaskOp createMaskOp,
162  vector::CreateMaskOp::Adaptor adaptor,
163  ConversionPatternRewriter &rewriter) const override {
164  auto maskType = createMaskOp.getVectorType();
  // Guard 1: only rank-1 scalable mask types are handled here.
165  if (maskType.getRank() != 1 || !maskType.isScalable())
166  return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
167 
168  // TODO: Support masks which are multiples of SVE predicates.
  // Guard 2: the size of dimension 0 must be a power of two in [2, 16].
169  auto maskBaseSize = maskType.getDimSize(0);
170  if (maskBaseSize < 2 || maskBaseSize > 16 ||
171  !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
172  return rewriter.notifyMatchFailure(createMaskOp,
173  "not SVE predicate-sized");
174 
175  auto loc = createMaskOp.getLoc();
  // The intrinsic is built with an i64 zero as its first operand and the
  // first adaptor operand (the mask bound after type conversion) as its
  // second; the result keeps the original mask type.
176  auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
177  rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
178  adaptor.getOperands()[0]);
179  return success();
180  }
181 };
182 
183 } // namespace
184 
185 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
187  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
188  // Populate conversion patterns
189 
190  // clang-format off
204  ConvertToSvboolOpLowering,
205  ConvertFromSvboolOpLowering,
206  ZipX2OpLowering,
207  ZipX4OpLowering,
208  PselOpLowering>(converter);
209  // Add vector.create_mask conversion with a high benefit as it produces much
210  // nicer code than the generic lowering.
211  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
212  // clang-format on
213 }
214 
216  LLVMConversionTarget &target) {
  // Marks the ArmSVE intrinsic ops as legal on `target` and the corresponding
  // higher-level ArmSVE ops as illegal, so that conversion must rewrite them.
217  // clang-format off
  // Ops that the lowerings in this file produce; they may remain in the IR.
218  target.addLegalOp<SdotIntrOp,
219  SmmlaIntrOp,
220  UdotIntrOp,
221  UmmlaIntrOp,
222  ScalableMaskedAddIIntrOp,
223  ScalableMaskedAddFIntrOp,
224  ScalableMaskedSubIIntrOp,
225  ScalableMaskedSubFIntrOp,
226  ScalableMaskedMulIIntrOp,
227  ScalableMaskedMulFIntrOp,
228  ScalableMaskedSDivIIntrOp,
229  ScalableMaskedUDivIIntrOp,
230  ScalableMaskedDivFIntrOp,
231  ConvertToSvboolIntrOp,
232  ConvertFromSvboolIntrOp,
233  ZipX2IntrOp,
234  ZipX4IntrOp,
235  PselIntrOp,
236  WhileLTIntrOp>();
  // Source ops that must not survive the conversion.
  // NOTE(review): `PselIntrOp` is legal above and a `PselOpLowering` pattern
  // is registered in this file, but `PselOp` is not listed as illegal here -
  // confirm whether that omission is intended.
  // `vector::CreateMaskOp` is likewise not listed; its lowering in this file
  // can report a match failure, so presumably it is left to other patterns -
  // TODO confirm.
237  target.addIllegalOp<SdotOp,
238  SmmlaOp,
239  UdotOp,
240  UmmlaOp,
241  ScalableMaskedAddIOp,
242  ScalableMaskedAddFOp,
243  ScalableMaskedSubIOp,
244  ScalableMaskedSubFOp,
245  ScalableMaskedMulIOp,
246  ScalableMaskedMulFOp,
247  ScalableMaskedSDivIOp,
248  ScalableMaskedUDivIOp,
249  ScalableMaskedDivFOp,
250  ConvertToSvboolOp,
251  ConvertFromSvboolOp,
252  ZipX2Op,
253  ZipX4Op>();
254  // clang-format on
255 }
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< SdotOp, SdotIntrOp > SdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
IntegerType getI1Type()
Definition: Builders.cpp:53
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:148
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:160
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:270
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:304
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.