MLIR 22.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

17
18using namespace mlir;
19using namespace mlir::arm_sve;
20
29 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
30 ScalableMaskedAddIIntrOp>;
32 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
33 ScalableMaskedAddFIntrOp>;
35 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
36 ScalableMaskedSubIIntrOp>;
38 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
39 ScalableMaskedSubFIntrOp>;
41 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
42 ScalableMaskedMulIIntrOp>;
44 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
45 ScalableMaskedMulFIntrOp>;
47 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
48 ScalableMaskedSDivIIntrOp>;
50 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
51 ScalableMaskedUDivIIntrOp>;
53 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
54 ScalableMaskedDivFIntrOp>;
55
56namespace {
57
58/// Unrolls a conversion to/from equivalent vector types, to allow using a
59/// conversion intrinsic that only supports 1-D vector types.
60///
61/// Example:
62/// ```
63/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
64/// ```
65/// is rewritten into:
66/// ```
67/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
68/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
69/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
70/// : (vector<[4]xi1>) -> vector<[16]xi1>
71/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
72/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
73/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
74/// : (vector<[4]xi1>) -> vector<[16]xi1>
75/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
76/// ```
77template <typename Op, typename IntrOp>
78struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
80
81 LogicalResult
82 matchAndRewrite(Op convertOp, typename Op::Adaptor,
83 ConversionPatternRewriter &rewriter) const override {
84 auto loc = convertOp.getLoc();
85
86 auto source = convertOp.getSource();
87 VectorType sourceType = source.getType();
88 VectorType resultType = convertOp.getResult().getType();
89
90 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
91 rewriter.getZeroAttr(resultType));
92
93 // We want to iterate over the input vector in steps of the trailing
94 // dimension. So this creates tile shape where all leading dimensions are 1,
95 // and the trailing dimension step is the size of the dimension.
96 SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
97 tileShape.back() = sourceType.getShape().back();
98
99 // Iterate over all scalable mask/predicate slices of the source vector.
101 StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
102 auto extractOrInsertPosition = ArrayRef(index).drop_back();
103 auto sourceVector = vector::ExtractOp::create(rewriter, loc, source,
104 extractOrInsertPosition);
105 VectorType convertedType =
106 VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
107 .setDim(0, resultType.getShape().back());
108 auto convertedVector =
109 IntrOp::create(rewriter, loc, TypeRange{convertedType}, sourceVector);
110 result = vector::InsertOp::create(rewriter, loc, convertedVector, result,
111 extractOrInsertPosition);
112 }
113
114 rewriter.replaceOp(convertOp, result);
115 return success();
116 }
117};
118
119using ConvertToSvboolOpLowering =
120 SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
121
122using ConvertFromSvboolOpLowering =
123 SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
124
127
128/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
129/// but first input (P1) and result predicates need conversion to/from svbool.
130struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
132
133 LogicalResult
134 matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
137 auto loc = pselOp.getLoc();
138 auto svboolP1 = ConvertToSvboolIntrOp::create(rewriter, loc, svboolType,
139 adaptor.getP1());
140 auto indexI32 = arith::IndexCastOp::create(
141 rewriter, loc, rewriter.getI32Type(), pselOp.getIndex());
142 auto pselIntr = PselIntrOp::create(rewriter, loc, svboolType, svboolP1,
143 pselOp.getP2(), indexI32);
144 rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
145 pselOp, adaptor.getP1().getType(), pselIntr);
146 return success();
147 }
148};
149
150/// Converts `vector.create_mask` ops that match the size of an SVE predicate
151/// to the `whilelt` intrinsic. This produces more canonical codegen than the
152/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
153/// for more details. Note that we can't use (the more general) active.lane.mask
154/// as its semantics don't neatly map on to `vector.create_mask`, as it does an
155/// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
156/// `n` is zero (whereas `create_mask` just returns an all-false mask).
157struct CreateMaskOpLowering
158 : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
160
161 LogicalResult
162 matchAndRewrite(vector::CreateMaskOp createMaskOp,
163 vector::CreateMaskOp::Adaptor adaptor,
164 ConversionPatternRewriter &rewriter) const override {
165 auto maskType = createMaskOp.getVectorType();
166 if (maskType.getRank() != 1 || !maskType.isScalable())
167 return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
168
169 // TODO: Support masks which are multiples of SVE predicates.
170 auto maskBaseSize = maskType.getDimSize(0);
171 if (maskBaseSize < 2 || maskBaseSize > 16 ||
172 !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
173 return rewriter.notifyMatchFailure(createMaskOp,
174 "not SVE predicate-sized");
175
176 auto loc = createMaskOp.getLoc();
177 auto zero = LLVM::ZeroOp::create(rewriter, loc, rewriter.getI64Type());
178 rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
179 adaptor.getOperands()[0]);
180 return success();
181 }
182};
183
184} // namespace
185
186/// Populate the given list with patterns that convert from ArmSVE to LLVM.
188 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
189 // Populate conversion patterns
190
191 // clang-format off
192 patterns.add<ConvertFromSvboolOpLowering,
193 ConvertToSvboolOpLowering,
195 PselOpLowering,
209 ZipX2OpLowering,
210 ZipX4OpLowering,
211 SdotOpLowering>(converter);
212 // Add vector.create_mask conversion with a high benefit as it produces much
213 // nicer code than the generic lowering.
214 patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
215 // clang-format on
216}
217
220 // clang-format off
221 target.addLegalOp<BfmmlaOp,
222 ConvertFromSvboolIntrOp,
223 ConvertToSvboolIntrOp,
224 DupQLaneIntrOp,
225 PselIntrOp,
226 ScalableMaskedAddFIntrOp,
227 ScalableMaskedAddIIntrOp,
228 ScalableMaskedDivFIntrOp,
229 ScalableMaskedMulFIntrOp,
230 ScalableMaskedMulIIntrOp,
231 ScalableMaskedSDivIIntrOp,
232 ScalableMaskedSubFIntrOp,
233 ScalableMaskedSubIIntrOp,
234 ScalableMaskedUDivIIntrOp,
235 SmmlaIntrOp,
236 UdotIntrOp,
237 UmmlaIntrOp,
238 UsmmlaIntrOp,
239 WhileLTIntrOp,
240 ZipX2IntrOp,
241 ZipX4IntrOp,
242 SdotIntrOp>();
243 target.addIllegalOp<ConvertFromSvboolOp,
244 ConvertToSvboolOp,
245 DupQLaneOp,
246 PselOp,
247 ScalableMaskedAddFOp,
248 ScalableMaskedAddIOp,
249 ScalableMaskedDivFOp,
250 ScalableMaskedMulFOp,
251 ScalableMaskedMulIOp,
252 ScalableMaskedSDivIOp,
253 ScalableMaskedSubFOp,
254 ScalableMaskedSubIOp,
255 ScalableMaskedUDivIOp,
256 SmmlaOp,
257 UdotOp,
258 UmmlaOp,
259 UsmmlaOp,
260 ZipX2Op,
261 ZipX4Op,
262 SdotOp>();
263 // clang-format on
264}
return success()
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< SdotOp, SdotIntrOp > SdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< UsmmlaOp, UsmmlaIntrOp > UsmmlaOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< DupQLaneOp, DupQLaneIntrOp > DupQLaneLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:207
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition Pattern.h:298
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
const FrozenRewritePatternSet & patterns
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.