MLIR  20.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::arm_sve;
22 
23 template <typename OpTy>
24 class ForwardOperands : public OpConversionPattern<OpTy> {
26 
27  LogicalResult
28  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
29  ConversionPatternRewriter &rewriter) const final {
30  if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
31  return rewriter.notifyMatchFailure(op, "operand types already match");
32 
33  rewriter.modifyOpInPlace(op,
34  [&]() { op->setOperands(adaptor.getOperands()); });
35  return success();
36  }
37 };
38 
44  OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
45  ScalableMaskedAddIIntrOp>;
47  OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
48  ScalableMaskedAddFIntrOp>;
50  OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
51  ScalableMaskedSubIIntrOp>;
53  OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
54  ScalableMaskedSubFIntrOp>;
56  OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
57  ScalableMaskedMulIIntrOp>;
59  OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
60  ScalableMaskedMulFIntrOp>;
62  OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
63  ScalableMaskedSDivIIntrOp>;
65  OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
66  ScalableMaskedUDivIIntrOp>;
68  OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
69  ScalableMaskedDivFIntrOp>;
70 
71 namespace {
72 
73 /// Unrolls a conversion to/from equivalent vector types, to allow using a
74 /// conversion intrinsic that only supports 1-D vector types.
75 ///
76 /// Example:
77 /// ```
78 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
79 /// ```
80 /// is rewritten into:
81 /// ```
82 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
83 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
84 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
85 /// : (vector<[4]xi1>) -> vector<[16]xi1>
86 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
87 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
88 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
89 /// : (vector<[4]xi1>) -> vector<[16]xi1>
90 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
91 /// ```
92 template <typename Op, typename IntrOp>
93 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
95 
96  LogicalResult
97  matchAndRewrite(Op convertOp, typename Op::Adaptor,
98  ConversionPatternRewriter &rewriter) const override {
99  auto loc = convertOp.getLoc();
100 
101  auto source = convertOp.getSource();
102  VectorType sourceType = source.getType();
103  VectorType resultType = convertOp.getResult().getType();
104 
105  Value result = rewriter.create<arith::ConstantOp>(
106  loc, resultType, rewriter.getZeroAttr(resultType));
107 
108  // We want to iterate over the input vector in steps of the trailing
109  // dimension. So this creates tile shape where all leading dimensions are 1,
110  // and the trailing dimension step is the size of the dimension.
111  SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
112  tileShape.back() = sourceType.getShape().back();
113 
114  // Iterate over all scalable mask/predicate slices of the source vector.
115  for (SmallVector<int64_t> index :
116  StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
117  auto extractOrInsertPosition = ArrayRef(index).drop_back();
118  auto sourceVector = rewriter.create<vector::ExtractOp>(
119  loc, source, extractOrInsertPosition);
120  VectorType convertedType =
121  VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
122  .setDim(0, resultType.getShape().back());
123  auto convertedVector =
124  rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
125  result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
126  extractOrInsertPosition);
127  }
128 
129  rewriter.replaceOp(convertOp, result);
130  return success();
131  }
132 };
133 
134 using ConvertToSvboolOpLowering =
135  SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
136 
137 using ConvertFromSvboolOpLowering =
138  SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
139 
142 
143 /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
144 /// but first input (P1) and result predicates need conversion to/from svbool.
145 struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
147 
148  LogicalResult
149  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
150  ConversionPatternRewriter &rewriter) const override {
151  auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
152  auto loc = pselOp.getLoc();
153  auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
154  adaptor.getP1());
155  auto indexI32 = rewriter.create<arith::IndexCastOp>(
156  loc, rewriter.getI32Type(), pselOp.getIndex());
157  auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
158  pselOp.getP2(), indexI32);
159  rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
160  pselOp, adaptor.getP1().getType(), pselIntr);
161  return success();
162  }
163 };
164 
165 /// Converts `vector.create_mask` ops that match the size of an SVE predicate
166 /// to the `whilelt` intrinsic. This produces more canonical codegen than the
167 /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
168 /// for more details. Note that we can't use (the more general) active.lane.mask
169 /// as its semantics don't neatly map on to `vector.create_mask`, as it does an
170 /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
171 /// `n` is zero (whereas `create_mask` just returns an all-false mask).
172 struct CreateMaskOpLowering
173  : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
175 
176  LogicalResult
177  matchAndRewrite(vector::CreateMaskOp createMaskOp,
178  vector::CreateMaskOp::Adaptor adaptor,
179  ConversionPatternRewriter &rewriter) const override {
180  auto maskType = createMaskOp.getVectorType();
181  if (maskType.getRank() != 1 || !maskType.isScalable())
182  return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
183 
184  // TODO: Support masks which are multiples of SVE predicates.
185  auto maskBaseSize = maskType.getDimSize(0);
186  if (maskBaseSize < 2 || maskBaseSize > 16 ||
187  !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
188  return rewriter.notifyMatchFailure(createMaskOp,
189  "not SVE predicate-sized");
190 
191  auto loc = createMaskOp.getLoc();
192  auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
193  rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
194  adaptor.getOperands()[0]);
195  return success();
196  }
197 };
198 
199 } // namespace
200 
201 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
203  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
204  // Populate conversion patterns
205 
206  // clang-format off
210  &converter.getContext());
211  patterns.add<SdotOpLowering,
224  ConvertToSvboolOpLowering,
225  ConvertFromSvboolOpLowering,
226  ZipX2OpLowering,
227  ZipX4OpLowering,
228  PselOpLowering>(converter);
229  // Add vector.create_mask conversion with a high benefit as it produces much
230  // nicer code than the generic lowering.
231  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
232  // clang-format on
233 }
234 
236  LLVMConversionTarget &target) {
237  // clang-format off
238  target.addLegalOp<SdotIntrOp,
239  SmmlaIntrOp,
240  UdotIntrOp,
241  UmmlaIntrOp,
242  ScalableMaskedAddIIntrOp,
243  ScalableMaskedAddFIntrOp,
244  ScalableMaskedSubIIntrOp,
245  ScalableMaskedSubFIntrOp,
246  ScalableMaskedMulIIntrOp,
247  ScalableMaskedMulFIntrOp,
248  ScalableMaskedSDivIIntrOp,
249  ScalableMaskedUDivIIntrOp,
250  ScalableMaskedDivFIntrOp,
251  ConvertToSvboolIntrOp,
252  ConvertFromSvboolIntrOp,
253  ZipX2IntrOp,
254  ZipX4IntrOp,
255  PselIntrOp,
256  WhileLTIntrOp>();
257  target.addIllegalOp<SdotOp,
258  SmmlaOp,
259  UdotOp,
260  UmmlaOp,
261  ScalableMaskedAddIOp,
262  ScalableMaskedAddFOp,
263  ScalableMaskedSubIOp,
264  ScalableMaskedSubFOp,
265  ScalableMaskedMulIOp,
266  ScalableMaskedMulFOp,
267  ScalableMaskedSDivIOp,
268  ScalableMaskedUDivIOp,
269  ScalableMaskedDivFOp,
270  ConvertToSvboolOp,
271  ConvertFromSvboolOp,
272  ZipX2Op,
273  ZipX4Op>();
274  // clang-format on
275 }
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< SdotOp, SdotIntrOp > SdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
IntegerType getI64Type()
Definition: Builders.cpp:89
IntegerType getI32Type()
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:335
IntegerType getI1Type()
Definition: Builders.cpp:77
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
MLIRContext & getContext() const
Returns the MLIR context.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
type_range getTypes() const
Definition: ValueRange.cpp:26
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:311
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:345
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...