MLIR 23.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
31#include "llvm/ADT/APFloat.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/Support/Casting.h"
34
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::vector;
39
40// Helper that picks the proper sequence for inserting.
41static Value insertOne(ConversionPatternRewriter &rewriter,
42 const LLVMTypeConverter &typeConverter, Location loc,
43 Value val1, Value val2, Type llvmType, int64_t rank,
44 int64_t pos) {
45 assert(rank > 0 && "0-D vector corner case should have been handled already");
46 if (rank == 1) {
47 auto idxType = rewriter.getIndexType();
48 auto constant = LLVM::ConstantOp::create(
49 rewriter, loc, typeConverter.convertType(idxType),
50 rewriter.getIntegerAttr(idxType, pos));
51 return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
52 constant);
53 }
54 return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
55}
56
57// Helper that picks the proper sequence for extracting.
58static Value extractOne(ConversionPatternRewriter &rewriter,
59 const LLVMTypeConverter &typeConverter, Location loc,
60 Value val, Type llvmType, int64_t rank, int64_t pos) {
61 if (rank <= 1) {
62 auto idxType = rewriter.getIndexType();
63 auto constant = LLVM::ConstantOp::create(
64 rewriter, loc, typeConverter.convertType(idxType),
65 rewriter.getIntegerAttr(idxType, pos));
66 return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
67 constant);
68 }
69 return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
70}
71
// Helper that returns data layout alignment of a vector.
//
// Converts `vectorType` through `typeConverter` and writes the preferred
// alignment of the converted type (per the converter's data layout) into
// `align`. Fails if the vector type cannot be converted to an LLVM type.
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
                                 VectorType vectorType, unsigned &align) {
  Type convertedVectorTy = typeConverter.convertType(vectorType);
  if (!convertedVectorTy)
    return failure();

  // A throwaway LLVMContext suffices: it is only needed to translate the
  // MLIR type to an LLVM IR type for the data layout query.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(convertedVectorTy,
                                     typeConverter.getDataLayout());

  return success();
}
86
// Helper that returns data layout alignment of a memref.
//
// Converts the memref's element type through `typeConverter` and writes its
// preferred alignment (per the converter's data layout) into `align`. Fails
// if the element type cannot be converted to an LLVM type.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}
101
102// Helper to resolve the alignment for vector load/store, gather and scatter
103// ops. If useVectorAlignment is true, get the preferred alignment for the
104// vector type in the operation. This option is used for hardware backends with
105// vectorization. Otherwise, use the preferred alignment of the element type of
106// the memref. Note that if you choose to use vector alignment, the shape of the
107// vector type must be resolved before the ConvertVectorToLLVM pass is run.
108LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
109 VectorType vectorType,
110 MemRefType memrefType, unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
113 if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
114 return failure();
115 }
116 } else {
117 if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
118 return failure();
119 }
120 }
121 return success();
122}
123
124// Check if the last stride is non-unit and has a valid memory space.
125static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
126 const LLVMTypeConverter &converter) {
127 if (!memRefType.isLastDimUnitStride())
128 return failure();
129 if (failed(converter.getMemRefAddressSpace(memRefType)))
130 return failure();
131 return success();
132}
133
// Add an index vector component to a base pointer.
//
// Builds a vector-of-pointers GEP: each lane addresses `base + index[lane]`
// (scaled by the converted element type). `llvmMemref` is the converted
// memref descriptor, used only to recover the element pointer type; the
// result vector mirrors the (possibly scalable) 1-D shape of `vectorType`.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            const LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, VectorType vectorType) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType =
      LLVM::getVectorType(pType, vectorType.getDimSize(0),
                          /*isScalable=*/vectorType.getScalableDims()[0]);
  return LLVM::GEPOp::create(
      rewriter, loc, ptrsType,
      typeConverter.convertType(memRefType.getElementType()), base, index);
}
150
151/// Convert `foldResult` into a Value. Integer attribute is converted to
152/// an LLVM constant op.
154 OpFoldResult foldResult) {
155 if (auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
158 }
159
160 return cast<Value>(foldResult);
161}
162
163namespace {
164
165/// Trivial Vector to LLVM conversions
166using VectorScaleOpConversion =
168
/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  /// Rewrites vector.bitcast into llvm.bitcast on the converted operand.
  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};
188
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Plain load: never volatile, nontemporal hint forwarded from the op.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
                                            /*volatile_=*/false,
                                            loadOp.getNontemporal());
}
200
/// Replaces a vector.maskedload with llvm.intr.masked.load, forwarding the
/// mask and pass-through operands from the adaptor.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}
208
/// Replaces a vector.store with llvm.store (never volatile, nontemporal hint
/// forwarded from the op). `vectorTy` is unused: the stored value fixes the
/// type.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align, /*volatile_=*/false,
                                             storeOp.getNontemporal());
}
217
/// Replaces a vector.maskedstore with llvm.intr.masked.store, forwarding the
/// value and mask from the adaptor.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
225
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
///
/// Shared skeleton: resolve the alignment, compute the strided element
/// pointer, then dispatch to the replaceLoadOrStoreOp overload that matches
/// the concrete op.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    // Explicit alignment takes priority over use-vector-alignment.
    unsigned align = loadOrStoreOp.getAlignment().value_or(0);
    if (!align &&
        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
                                        memRefTy, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(loadOrStoreOp,
                                         "could not resolve alignment");

    // Resolve address of the accessed element from the base + indices.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(
        rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
    // Overload resolution picks the LLVM counterpart for this op kind.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                         rewriter);
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
275
/// Conversion pattern for a vector.gather.
///
/// Lowers a 1-D gather from a contiguous memref to llvm.intr.masked.gather:
/// the scalar base pointer is widened into a vector of per-lane pointers via
/// the index vector, and mask/pass-through are forwarded.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
                                    bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gather->getLoc();
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // TODO: Add support for strided MemRef.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(gather, "memref type not supported");

    VectorType vType = gather.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          gather, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    // Explicit alignment takes priority over use-vector-alignment.
    unsigned align = gather.getAlignment().value_or(0);
    if (!align &&
        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

    // Resolve address: scalar pointer at base + offsets, then one pointer
    // per lane from the index vector.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getOffsets());
    Value base = adaptor.getBase();
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       base, ptr, adaptor.getIndices(), vType);

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
333
/// Conversion pattern for a vector.scatter.
///
/// Mirrors VectorGatherOpConversion: lowers a 1-D scatter into
/// llvm.intr.masked.scatter through a vector of per-lane pointers.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // TODO: Add support for strided MemRef.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(scatter, "memref type not supported");

    VectorType vType = scatter.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          scatter, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    // Explicit alignment takes priority over use-vector-alignment.
    unsigned align = scatter.getAlignment().value_or(0);
    if (!align &&
        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(scatter,
                                         "could not resolve alignment");

    // Resolve address: scalar pointer at base + offsets, then one pointer
    // per lane from the index vector.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getOffsets());
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndices(), vType);

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
392
/// Conversion pattern for a vector.expandload.
///
/// Lowers to llvm.intr.masked.expandload at the strided element address,
/// forwarding mask and pass-through.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // From:
    // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
    // The pointer alignment defaults to 1.
    uint64_t alignment = expand.getAlignment().value_or(1);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
        alignment);
    return success();
  }
};
421
/// Conversion pattern for a vector.compressstore.
///
/// Lowers to llvm.intr.masked.compressstore at the strided element address,
/// forwarding the value and mask.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // From:
    // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
    // The pointer alignment defaults to 1.
    uint64_t alignment = compress.getAlignment().value_or(1);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
    return success();
  }
};
448
/// Reduction neutral classes for overloading.
/// Empty tag types that select, via overload resolution, which neutral
/// (identity) element `createReductionNeutralValue` materializes for a
/// given reduction kind.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
460
461/// Create the reduction neutral zero value.
462static Value createReductionNeutralValue(ReductionNeutralZero neutral,
463 ConversionPatternRewriter &rewriter,
464 Location loc, Type llvmType) {
465 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
466 rewriter.getZeroAttr(llvmType));
467}
468
469/// Create the reduction neutral integer one value.
470static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
471 ConversionPatternRewriter &rewriter,
472 Location loc, Type llvmType) {
473 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
474 rewriter.getIntegerAttr(llvmType, 1));
475}
476
477/// Create the reduction neutral fp one value.
478static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
479 ConversionPatternRewriter &rewriter,
480 Location loc, Type llvmType) {
481 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
482 rewriter.getFloatAttr(llvmType, 1.0));
483}
484
485/// Create the reduction neutral all-ones value.
486static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
487 ConversionPatternRewriter &rewriter,
488 Location loc, Type llvmType) {
489 return LLVM::ConstantOp::create(
490 rewriter, loc, llvmType,
491 rewriter.getIntegerAttr(
492 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
493}
494
495/// Create the reduction neutral signed int minimum value.
496static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
497 ConversionPatternRewriter &rewriter,
498 Location loc, Type llvmType) {
499 return LLVM::ConstantOp::create(
500 rewriter, loc, llvmType,
501 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
502 llvmType.getIntOrFloatBitWidth())));
503}
504
505/// Create the reduction neutral unsigned int minimum value.
506static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
507 ConversionPatternRewriter &rewriter,
508 Location loc, Type llvmType) {
509 return LLVM::ConstantOp::create(
510 rewriter, loc, llvmType,
511 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
512 llvmType.getIntOrFloatBitWidth())));
513}
514
515/// Create the reduction neutral signed int maximum value.
516static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
517 ConversionPatternRewriter &rewriter,
518 Location loc, Type llvmType) {
519 return LLVM::ConstantOp::create(
520 rewriter, loc, llvmType,
521 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
522 llvmType.getIntOrFloatBitWidth())));
523}
524
525/// Create the reduction neutral unsigned int maximum value.
526static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
527 ConversionPatternRewriter &rewriter,
528 Location loc, Type llvmType) {
529 return LLVM::ConstantOp::create(
530 rewriter, loc, llvmType,
531 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
532 llvmType.getIntOrFloatBitWidth())));
533}
534
535/// Create the reduction neutral fp minimum value.
536static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
537 ConversionPatternRewriter &rewriter,
538 Location loc, Type llvmType) {
539 auto floatType = cast<FloatType>(llvmType);
540 return LLVM::ConstantOp::create(
541 rewriter, loc, llvmType,
542 rewriter.getFloatAttr(
543 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
544 /*Negative=*/false)));
545}
546
547/// Create the reduction neutral fp maximum value.
548static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
549 ConversionPatternRewriter &rewriter,
550 Location loc, Type llvmType) {
551 auto floatType = cast<FloatType>(llvmType);
552 return LLVM::ConstantOp::create(
553 rewriter, loc, llvmType,
554 rewriter.getFloatAttr(
555 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
556 /*Negative=*/true)));
557}
558
559/// Returns `accumulator` if it has a valid value. Otherwise, creates and
560/// returns a new accumulator value using `ReductionNeutral`.
561template <class ReductionNeutral>
562static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
563 Location loc, Type llvmType,
564 Value accumulator) {
565 if (accumulator)
566 return accumulator;
567
568 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
569 llvmType);
570}
571
/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  // Static (base) length as an i32 constant.
  Value baseVecLength = LLVM::ConstantOp::create(
      rewriter, loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));

  if (!vType.getScalableDims()[0])
    return baseVecLength;

  // For a scalable vector type, create and return `vScale * baseVecLength`.
  Value vScale = vector::VectorScaleOp::create(rewriter, loc);
  vScale =
      arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
  Value scalableVecLength =
      arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
  return scalableVecLength;
}
596
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector
/// reduction intrinsic to use and `ScalarOp` is the scalar operation used to
/// fold in the accumulation value if non-null.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result =
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);

  // Combine with the optional accumulator using the matching scalar op.
  if (accumulator)
    result = ScalarOp::create(rewriter, loc, accumulator, result);
  return result;
}
613
/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector reduction intrinsic to use and `predicate` is the predicate to use
/// to compare+combine the accumulator value if non-null.
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result =
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
  if (accumulator) {
    // Fold in the accumulator with an icmp + select pair.
    Value cmp =
        LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
    result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
  }
  return result;
}
631
namespace {
/// Maps an LLVM FP vector reduction intrinsic to the scalar LLVM op used to
/// combine the reduction result with an accumulator value.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
652
/// Lowers an FP min/max-style `vector.reduction` to the given LLVM vector
/// reduction intrinsic, combining an optional accumulator with the matching
/// scalar op (see VectorToScalarMapper). Fast-math flags are forwarded to
/// the intrinsic.
template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  Value result =
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);

  if (accumulator) {
    result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
        rewriter, loc, result, accumulator);
  }

  return result;
}
667
/// Mask neutral classes for overloading: tag types selecting which value is
/// substituted into masked-off lanes before an unmasked reduction.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
671
672/// Get the mask neutral floating point maximum value
673static llvm::APFloat
674getMaskNeutralValue(MaskNeutralFMaximum,
675 const llvm::fltSemantics &floatSemantics) {
676 return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
677}
/// Get the mask neutral floating point minimum value: the largest positive
/// finite value, which cannot win an fminimum reduction against any finite
/// active lane.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
}
684
685/// Create the mask neutral floating point MLIR vector constant
686template <typename MaskNeutral>
687static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
688 Location loc, Type llvmType,
689 Type vectorType) {
690 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
691 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
692 auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
693 return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
694}
695
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics
/// for `fmaximum`/`fminimum`.
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
                                Location loc, Type llvmType,
                                Value vectorOperand, Value accumulator,
                                Value mask, LLVM::FastmathFlagsAttr fmf) {
  // Replace masked-off lanes with a value that cannot affect the reduction,
  // then run the regular (unmasked) lowering.
  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
      rewriter, loc, llvmType, vectorOperand.getType());
  const Value selectedVectorByMask = LLVM::SelectOp::create(
      rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
      rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}
713
/// Lowers an FP reduction to an LLVM intrinsic that takes an explicit start
/// value, defaulting the accumulator to the `ReductionNeutral` element when
/// none was provided.
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
                                 /*start_value=*/accumulator, vectorOperand,
                                 fmf);
}
725
/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
                                   /*startValue=*/accumulator, vectorOperand);
}
739
/// Masked variant: the VP intrinsic additionally takes the mask and an
/// effective vector length operand (see createVectorLengthValue).
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
                                   /*start_value=*/accumulator, vectorOperand,
                                   mask, vectorLength);
}
752
/// Dispatching variant: selects the integer or floating-point predicated
/// reduction lowering based on the reduction's element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  // Integer dispatch.
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}
768
769/// Conversion pattern for all vector reductions.
770class VectorReductionOpConversion
771 : public ConvertOpToLLVMPattern<vector::ReductionOp> {
772public:
773 explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
774 bool reassociateFPRed)
775 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
776 reassociateFPReductions(reassociateFPRed) {}
777
  /// Lowers `vector.reduction` to an LLVM vector-reduce intrinsic, folding in
  /// the optional accumulator operand where the intrinsic does not take one
  /// directly. Dispatches on the combining kind and on whether the scalar
  /// result type is an integer/index or a float.
  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    // The scalar result (destination) type decides between the integer and
    // the floating-point lowering paths below.
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      // Arithmetic kinds combine the accumulator with a scalar arith op;
      // min/max kinds combine it via the integer compare predicate passed to
      // the comparison lowering helper.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds (the floating-point-only ones) are not valid with
        // an integer element type.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!isa<FloatType>(eltType))
      return failure();

    // Translate the op's arith fast-math flags to their LLVM equivalents and,
    // when the pass was configured with `reassociateFPReductions`, also set
    // `reassoc` so the backend may reassociate the reduction.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
    fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
                                                  : LLVM::FastmathFlags::none));

    // Floating-point reductions: add/mul/min/max
    // add/mul intrinsics take an explicit start value, filled with the
    // reduction's neutral element when no accumulator is present.
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXIMUMF) {
      result =
          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
              rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MINNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else if (kind == vector::CombiningKind::MAXNUMF) {
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc, fmf);
    } else {
      // Integer-only combining kinds are invalid with a float element type.
      return failure();
    }

    rewriter.replaceOp(reductionOp, result);
    return success();
  }
893
894private:
895 const bool reassociateFPReductions;
896};
897
898/// Base class to convert a `vector.mask` operation while matching traits
899/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
900/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
901/// method performs a second match against the maskable operation `MaskedOp`.
902/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
903/// implemented by the concrete conversion classes. This method can match
904/// against specific traits of the `vector.mask` and the maskable operation. It
905/// must replace the `vector.mask` operation.
906template <class MaskedOp>
907class VectorMaskOpConversionBase
908 : public ConvertOpToLLVMPattern<vector::MaskOp> {
909public:
910 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
911
912 LogicalResult
913 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
914 ConversionPatternRewriter &rewriter) const final {
915 // Match against the maskable operation kind.
916 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
917 if (!maskedOp)
918 return failure();
919 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
920 }
921
922protected:
923 virtual LogicalResult
924 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
925 vector::MaskableOpInterface maskableOp,
926 ConversionPatternRewriter &rewriter) const = 0;
927};
928
/// Converts a masked `vector.reduction` (i.e. one nested under `vector.mask`)
/// to an LLVM vector-predication (`vp.reduce.*`) intrinsic when one exists.
/// `maximumf`/`minimumf` have no VP intrinsic and are instead lowered through
/// the regular reduce intrinsic after substituting a mask-neutral value into
/// the inactive lanes.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    // NOTE(review): operands are read off the reduction op itself rather than
    // an adaptor here — presumably already converted by the enclosing mask
    // conversion; verify against the driver.
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Fast-math flags are consumed only by the maximumf/minimumf cases below;
    // the VP-intrinsic paths do not take them.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    // Each VP lowering is parameterized on the reduction's neutral element,
    // which serves as the start value of the predicated reduction.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    // No VP intrinsic exists for maximumf/minimumf: fall back to the regular
    // intrinsic with mask-neutral values in the masked-off lanes.
    case CombiningKind::MAXIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }
    // No default case above: the switch covers every CombiningKind, so
    // `result` is always assigned at this point.

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};
1028
1029class VectorShuffleOpConversion
1030 : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
1031public:
1032 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
1033
1034 LogicalResult
1035 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1036 ConversionPatternRewriter &rewriter) const override {
1037 auto loc = shuffleOp->getLoc();
1038 auto v1Type = shuffleOp.getV1VectorType();
1039 auto v2Type = shuffleOp.getV2VectorType();
1040 auto vectorType = shuffleOp.getResultVectorType();
1041 Type llvmType = typeConverter->convertType(vectorType);
1042 ArrayRef<int64_t> mask = shuffleOp.getMask();
1043
1044 // Bail if result type cannot be lowered.
1045 if (!llvmType)
1046 return failure();
1047
1048 // Get rank and dimension sizes.
1049 int64_t rank = vectorType.getRank();
1050#ifndef NDEBUG
1051 bool wellFormed0DCase =
1052 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1053 bool wellFormedNDCase =
1054 v1Type.getRank() == rank && v2Type.getRank() == rank;
1055 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1056#endif
1057
1058 // For rank 0 and 1, where both operands have *exactly* the same vector
1059 // type, there is direct shuffle support in LLVM. Use it!
1060 if (rank <= 1 && v1Type == v2Type) {
1061 Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1062 rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1063 llvm::to_vector_of<int32_t>(mask));
1064 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1065 return success();
1066 }
1067
1068 // For all other cases, insert the individual values individually.
1069 int64_t v1Dim = v1Type.getDimSize(0);
1070 Type eltType;
1071 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1072 eltType = arrayType.getElementType();
1073 else
1074 eltType = cast<VectorType>(llvmType).getElementType();
1075 Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1076 int64_t insPos = 0;
1077 for (int64_t extPos : mask) {
1078 Value value = adaptor.getV1();
1079 if (extPos >= v1Dim) {
1080 extPos -= v1Dim;
1081 value = adaptor.getV2();
1082 }
1083 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1084 eltType, rank, extPos);
1085 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1086 llvmType, rank, insPos++);
1087 }
1088 rewriter.replaceOp(shuffleOp, insert);
1089 return success();
1090 }
1091};
1092
/// Converts `vector.extract` to `llvm.extractvalue` (for positions into the
/// outer, aggregate-modeled dimensions) and/or `llvm.extractelement` (for a
/// final scalar extraction out of a 1-D vector).
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Mixed static/dynamic extraction positions, as attributes or Values.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
    // 1-d vectors. This nesting is modeled using arrays. We do this conversion
    // from a N-d vector extract to a nested aggregate vector extract in two
    // steps:
    //  - Extract a member from the nested aggregate. The result can be
    //    a lower rank nested aggregate or a vector (1-D). This is done using
    //    `llvm.extractvalue`.
    //  - Extract a scalar out of the vector if needed. This is done using
    //    `llvm.extractelement`.

    // Determine if we need to extract a member out of the aggregate. We
    // always need to extract a member if the input rank >= 2.
    bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
    // Determine if we need to extract a scalar as the result. We extract
    // a scalar if the extract is full rank, i.e., the number of indices is
    // equal to source vector rank.
    bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
                          extractOp.getSourceVectorType().getRank();

    // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
    // need to add a position for this change.
    if (extractOp.getSourceVectorType().getRank() == 0) {
      Type idxType = typeConverter->convertType(rewriter.getIndexType());
      positionVec.push_back(rewriter.getZeroAttr(idxType));
    }

    Value extracted = adaptor.getSource();
    if (extractsAggregate) {
      ArrayRef<OpFoldResult> position(positionVec);
      if (extractsScalar) {
        // If we are extracting a scalar from the extracted member, we drop
        // the last index, which will be used to extract the scalar out of the
        // vector.
        position = position.drop_back();
      }
      // llvm.extractvalue does not support dynamic dimensions.
      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
        return failure();
      }
      extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
                                               getAsIntegers(position));
    }

    if (extractsScalar) {
      // The last (possibly dynamic) position indexes into the 1-D vector.
      extracted = LLVM::ExtractElementOp::create(
          rewriter, loc, extracted,
          getAsLLVMValue(rewriter, loc, positionVec.back()));
    }

    rewriter.replaceOp(extractOp, extracted);
    return success();
  }
};
1164
1165/// Conversion pattern that turns a vector.fma on a 1-D vector
1166/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1167/// This does not match vectors of n >= 2 rank.
1168///
1169/// Example:
1170/// ```
1171/// vector.fma %a, %a, %a : vector<8xf32>
1172/// ```
1173/// is converted to:
1174/// ```
1175/// llvm.intr.fmuladd %va, %va, %va:
1176/// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1177/// -> !llvm."<8 x f32>">
1178/// ```
1179class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1180public:
1181 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1182
1183 LogicalResult
1184 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1185 ConversionPatternRewriter &rewriter) const override {
1186 VectorType vType = fmaOp.getVectorType();
1187 if (vType.getRank() > 1)
1188 return failure();
1189
1190 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1191 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1192 return success();
1193 }
1194};
1195
/// Converts `vector.insert`, mirroring `VectorExtractOpConversion`: positions
/// into the outer (aggregate-modeled) dimensions become
/// `llvm.extractvalue`/`llvm.insertvalue` indices, while an innermost-dim
/// position becomes an `llvm.insertelement` into a 1-D vector.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Mixed static/dynamic insertion positions, as attributes or Values.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
    // its explanatory comment about how N-D vectors are converted as nested
    // aggregates (llvm.array's) of 1D vectors.
    //
    // The innermost dimension of the destination vector, when converted to a
    // nested aggregate form, will always be a 1D vector.
    //
    // * If the insertion is happening into the innermost dimension of the
    //   destination vector:
    //   - If the destination is a nested aggregate, extract a 1D vector out of
    //     the aggregate. This can be done using llvm.extractvalue. The
    //     destination is now guaranteed to be a 1D vector, to which we are
    //     inserting.
    //   - Do the insertion into the 1D destination vector, and make the result
    //     the new source nested aggregate. This can be done using
    //     llvm.insertelement.
    // * Insert the source nested aggregate into the destination nested
    //   aggregate.

    // Determine if we need to extract/insert a 1D vector out of the aggregate.
    bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
    // Determine if we need to insert a scalar into the 1D vector.
    bool insertIntoInnermostDim =
        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();

    // All positions except (when inserting a scalar) the last one index the
    // 1-D vector within the aggregate.
    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
        positionVec.begin(),
        insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
    OpFoldResult positionOfScalarWithin1DVector;
    if (destVectorType.getRank() == 0) {
      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
      // need to create a 0 here as the position into the 1D vector.
      Type idxType = typeConverter->convertType(rewriter.getIndexType());
      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
    } else if (insertIntoInnermostDim) {
      positionOfScalarWithin1DVector = positionVec.back();
    }

    // We are going to mutate this 1D vector until it is either the final
    // result (in the non-aggregate case) or the value that needs to be
    // inserted into the aggregate result.
    Value sourceAggregate = adaptor.getValueToStore();
    if (insertIntoInnermostDim) {
      // Scalar-into-1D-vector case, so we know we will have to create a
      // InsertElementOp. The question is into what destination.
      if (isNestedAggregate) {
        // Aggregate case: the destination for the InsertElementOp needs to be
        // extracted from the aggregate.
        if (!llvm::all_of(positionOf1DVectorWithinAggregate,
                          llvm::IsaPred<Attribute>)) {
          // llvm.extractvalue does not support dynamic dimensions.
          return failure();
        }
        sourceAggregate = LLVM::ExtractValueOp::create(
            rewriter, loc, adaptor.getDest(),
            getAsIntegers(positionOf1DVectorWithinAggregate));
      } else {
        // No-aggregate case. The destination for the InsertElementOp is just
        // the insertOp's destination.
        sourceAggregate = adaptor.getDest();
      }
      // Insert the scalar into the 1D vector.
      sourceAggregate = LLVM::InsertElementOp::create(
          rewriter, loc, sourceAggregate.getType(), sourceAggregate,
          adaptor.getValueToStore(),
          getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
    }

    Value result = sourceAggregate;
    if (isNestedAggregate) {
      if (!llvm::all_of(positionOf1DVectorWithinAggregate,
                        llvm::IsaPred<Attribute>)) {
        // llvm.insertvalue does not support dynamic dimensions.
        return failure();
      }
      result = LLVM::InsertValueOp::create(
          rewriter, loc, adaptor.getDest(), sourceAggregate,
          getAsIntegers(positionOf1DVectorWithinAggregate));
    }

    rewriter.replaceOp(insertOp, result);
    return success();
  }
};
1298
1299/// Lower vector.scalable.insert ops to LLVM vector.insert
1300struct VectorScalableInsertOpLowering
1301 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1302 using ConvertOpToLLVMPattern<
1303 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1304
1305 LogicalResult
1306 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1307 ConversionPatternRewriter &rewriter) const override {
1308 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1309 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1310 return success();
1311 }
1312};
1313
1314/// Lower vector.scalable.extract ops to LLVM vector.extract
1315struct VectorScalableExtractOpLowering
1316 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1317 using ConvertOpToLLVMPattern<
1318 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1319
1320 LogicalResult
1321 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1322 ConversionPatternRewriter &rewriter) const override {
1323 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1324 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1325 adaptor.getSource(), adaptor.getPos());
1326 return success();
1327 }
1328};
1329
1330/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1331///
1332/// Example:
1333/// ```
1334/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1335/// ```
1336/// is rewritten into:
1337/// ```
1338/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
1339/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1340/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1341/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1342/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1343/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1344/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1345/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1346/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1347/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1348/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1349/// // %r3 holds the final value.
1350/// ```
1351class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1352public:
1353 using Base::Base;
1354
1355 void initialize() {
1356 // This pattern recursively unpacks one dimension at a time. The recursion
1357 // bounded as the rank is strictly decreasing.
1358 setHasBoundedRewriteRecursion();
1359 }
1360
1361 LogicalResult matchAndRewrite(FMAOp op,
1362 PatternRewriter &rewriter) const override {
1363 auto vType = op.getVectorType();
1364 if (vType.getRank() < 2)
1365 return failure();
1366
1367 auto loc = op.getLoc();
1368 auto elemType = vType.getElementType();
1369 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1370 rewriter.getZeroAttr(elemType));
1371 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1372 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1373 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1374 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1375 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1376 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1377 desc = InsertOp::create(rewriter, loc, fma, desc, i);
1378 }
1379 rewriter.replaceOp(op, desc);
1380 return success();
1381 }
1382};
1383
1384/// Returns the strides if the memory underlying `memRefType` has a contiguous
1385/// static layout.
1386static std::optional<SmallVector<int64_t, 4>>
1387computeContiguousStrides(MemRefType memRefType) {
1388 int64_t offset;
1390 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1391 return std::nullopt;
1392 if (!strides.empty() && strides.back() != 1)
1393 return std::nullopt;
1394 // If no layout or identity layout, this is contiguous by definition.
1395 if (memRefType.getLayout().isIdentity())
1396 return strides;
1397
1398 // Otherwise, we must determine contiguity form shapes. This can only ever
1399 // work in static cases because MemRefType is underspecified to represent
1400 // contiguous dynamic shapes in other ways than with just empty/identity
1401 // layout.
1402 auto sizes = memRefType.getShape();
1403 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1404 if (ShapedType::isDynamic(sizes[index + 1]) ||
1405 ShapedType::isDynamic(strides[index]) ||
1406 ShapedType::isDynamic(strides[index + 1]))
1407 return std::nullopt;
1408 if (strides[index] != strides[index + 1] * sizes[index + 1])
1409 return std::nullopt;
1410 }
1411 return strides;
1412}
1413
/// Converts `vector.type_cast` between statically-shaped, contiguous memrefs
/// by building a fresh target memref descriptor that reuses the source's
/// allocated/aligned pointers and re-derives offset, sizes, and strides from
/// the target type.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The converted source operand must already be a memref descriptor
    // (an LLVM struct).
    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    // Sizes/strides come from the *target* type; they are static, so they
    // materialize as i64 constants.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride =
          LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
1489
1490/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1491/// Non-scalable versions of this operation are handled in Vector Transforms.
1492class VectorCreateMaskOpConversion
1493 : public OpConversionPattern<vector::CreateMaskOp> {
1494public:
1495 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1496 bool enableIndexOpt)
1497 : OpConversionPattern<vector::CreateMaskOp>(context),
1498 force32BitVectorIndices(enableIndexOpt) {}
1499
1500 LogicalResult
1501 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1502 ConversionPatternRewriter &rewriter) const override {
1503 auto dstType = op.getType();
1504 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1505 return failure();
1506 IntegerType idxType =
1507 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1508 auto loc = op->getLoc();
1509 Value indices = LLVM::StepVectorOp::create(
1510 rewriter, loc,
1511 LLVM::getVectorType(idxType, dstType.getShape()[0],
1512 /*isScalable=*/true));
1513 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1514 adaptor.getOperands()[0]);
1515 Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1516 Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1517 indices, bounds);
1518 rewriter.replaceOp(op, comp);
1519 return success();
1520 }
1521
1522private:
1523 const bool force32BitVectorIndices;
1524};
1525
1526class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1527 SymbolTableCollection *symbolTables = nullptr;
1528
1529public:
  /// Creates the pattern. `symbolTables`, when non-null, is forwarded to the
  /// `LLVM::lookupOrCreate*Fn` helpers used to declare the runtime print
  /// functions.
  explicit VectorPrintOpConversion(
      const LLVMTypeConverter &typeConverter,
      SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
        symbolTables(symbolTables) {}
1535
1536 // Lowering implementation that relies on a small runtime support library,
1537 // which only needs to provide a few printing methods (single value for all
1538 // data types, opening/closing bracket, comma, newline). The lowering splits
1539 // the vector into elementary printing operations. The advantage of this
1540 // approach is that the library can remain unaware of all low-level
1541 // implementation details of vectors while still supporting output of any
1542 // shaped and dimensioned vector.
1543 //
1544 // Note: This lowering only handles scalars, n-D vectors are broken into
1545 // printing scalars in loops in VectorToSCF.
1546 //
1547 // TODO: rely solely on libc in future? something else?
1548 //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The runtime print functions are declared at module scope.
    auto parent = printOp->getParentOfType<ModuleOp>();
    if (!parent)
      return failure();

    auto loc = printOp->getLoc();

    // First print the scalar source value, if present.
    if (auto value = adaptor.getSource()) {
      Type printType = printOp.getPrintType();
      if (isa<VectorType>(printType)) {
        // Vectors should be broken into elementary print ops in VectorToSCF.
        return failure();
      }
      if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
        return failure();
    }

    // Then emit the trailing string literal or punctuation, if any; a string
    // literal takes precedence over punctuation.
    auto punct = printOp.getPunctuation();
    if (auto stringLiteral = printOp.getStringLiteral()) {
      auto createResult =
          LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
                                   *stringLiteral, *getTypeConverter(),
                                   /*addNewline=*/false);
      if (createResult.failed())
        return failure();

    } else if (punct != PrintPunctuation::NoPunctuation) {
      // Look up (or declare) the runtime function for this punctuation kind.
      FailureOr<LLVM::LLVMFuncOp> op = [&]() {
        switch (punct) {
        case PrintPunctuation::Close:
          return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
                                                  symbolTables);
        case PrintPunctuation::Open:
          return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
                                                 symbolTables);
        case PrintPunctuation::Comma:
          return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
                                                  symbolTables);
        case PrintPunctuation::NewLine:
          return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
                                                    symbolTables);
        default:
          llvm_unreachable("unexpected punctuation");
        }
      }();
      if (failed(op))
        return failure();
      emitCall(rewriter, printOp->getLoc(), op.value());
    }

    // vector.print produces no results; erase it once lowered.
    rewriter.eraseOp(printOp);
    return success();
  }
1604
1605private:
  /// Adaptation applied to a scalar value before passing it to its runtime
  /// print function (see emitScalarPrint).
  enum class PrintConversion {
    // clang-format off
    None,       // Pass the value through unchanged.
    ZeroExt64,  // Zero-extend to i64 (unsigned ints, and i1 so true prints as 1).
    SignExt64,  // Sign-extend to i64 (signed/signless ints narrower than 64 bits).
    Bitcast16   // Reinterpret f16/bf16 as their raw 16 bits.
    // clang-format on
  };
1614
  /// Lowers the printing of one scalar `value` of type `printType` to a call
  /// into the runtime print library. Selects the runtime function and any
  /// required operand pre-conversion based on the element type. Returns
  /// failure when the type is unsupported or cannot be converted.
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    // Bail out early when the type converter cannot handle the element type.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    FailureOr<Operation *> printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isIndex()) {
      // Index values are printed as unsigned 64-bit integers.
      printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer =
              LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
        } else {
          // Integers wider than 64 bits have no runtime support.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer =
              LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
        } else {
          return failure();
        }
      }
    } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
      // Print other floating-point types using the APFloat runtime library:
      // the callee receives the float-semantics enum and the bits in an i64.
      int32_t sem =
          llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
      Value semValue = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(),
          rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
      // NOTE(review): `value` is zero-extended to i64 here, which presumes it
      // already carries an integer bit-pattern when this branch is reached —
      // confirm with the callers of emitScalarPrint.
      Value floatBits =
          LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
      printer =
          LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
      // NOTE(review): unlike the other branches, `printer` is used without a
      // failed() check here — confirm lookupOrCreateApFloatPrintFn cannot
      // fail in this context.
      emitCall(rewriter, loc, printer.value(),
               ValueRange({semValue, floatBits}));
      return success();
    } else {
      return failure();
    }
    if (failed(printer))
      return failure();

    // Apply the operand pre-conversion selected above, then call the printer.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = arith::ExtUIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = arith::ExtSIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      // Reinterpret f16/bf16 as its 16-bit integer bit pattern.
      value = LLVM::BitcastOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer.value(), value);
    return success();
  }
1704
1705 // Helper to emit a call.
1706 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1707 Operation *ref, ValueRange params = ValueRange()) {
1708 LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref),
1709 params);
1710 }
1711};
1712
1713/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
1714/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1715/// pattern, the higher rank cases are handled by another pattern.
1716struct VectorBroadcastScalarToLowRankLowering
1717 : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1718 using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1719
1720 LogicalResult
1721 matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
1722 ConversionPatternRewriter &rewriter) const override {
1723 if (isa<VectorType>(broadcast.getSourceType()))
1724 return rewriter.notifyMatchFailure(
1725 broadcast, "broadcast from vector type not handled");
1726
1727 VectorType resultType = broadcast.getType();
1728 if (resultType.getRank() > 1)
1729 return rewriter.notifyMatchFailure(broadcast,
1730 "broadcast to 2+-d handled elsewhere");
1731
1732 // First insert it into a poison vector so we can shuffle it.
1733 auto vectorType = typeConverter->convertType(broadcast.getType());
1734 Value poison =
1735 LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType);
1736 auto zero = LLVM::ConstantOp::create(
1737 rewriter, broadcast.getLoc(),
1738 typeConverter->convertType(rewriter.getIntegerType(32)),
1739 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1740
1741 // For 0-d vector, we simply do `insertelement`.
1742 if (resultType.getRank() == 0) {
1743 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1744 broadcast, vectorType, poison, adaptor.getSource(), zero);
1745 return success();
1746 }
1747
1748 auto v =
1749 LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
1750 poison, adaptor.getSource(), zero);
1751
1752 // For 1-d vector, we additionally do a `shufflevector`.
1753 int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
1754 SmallVector<int32_t> zeroValues(width, 0);
1755
1756 // Shuffle the value across the desired number of elements.
1757 auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1758 broadcast.getLoc(), v, poison, zeroValues);
1759 rewriter.replaceOp(broadcast, shuffle);
1760 return success();
1761 }
1762};
1763
1764/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
1765/// operation. Only broadcasts to 2+-d vector result types are lowered by this
1766/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1767/// are not converted to LLVM, only broadcasts from scalars are.
1768struct VectorBroadcastScalarToNdLowering
1769 : public ConvertOpToLLVMPattern<BroadcastOp> {
1770 using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1771
1772 LogicalResult
1773 matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
1774 ConversionPatternRewriter &rewriter) const override {
1775 if (isa<VectorType>(broadcast.getSourceType()))
1776 return rewriter.notifyMatchFailure(
1777 broadcast, "broadcast from vector type not handled");
1778
1779 VectorType resultType = broadcast.getType();
1780 if (resultType.getRank() <= 1)
1781 return rewriter.notifyMatchFailure(
1782 broadcast, "broadcast to 1-d or 0-d handled elsewhere");
1783
1784 // First insert it into an undef vector so we can shuffle it.
1785 auto loc = broadcast.getLoc();
1786 auto vectorTypeInfo =
1787 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1788 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1789 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1790 if (!llvmNDVectorTy || !llvm1DVectorTy)
1791 return failure();
1792
1793 // Construct returned value.
1794 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1795
1796 // Construct a 1-D vector with the broadcasted value that we insert in all
1797 // the places within the returned descriptor.
1798 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1799 auto zero = LLVM::ConstantOp::create(
1800 rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1801 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1802 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1803 vdesc, adaptor.getSource(), zero);
1804
1805 // Shuffle the value across the desired number of elements.
1806 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1807 SmallVector<int32_t> zeroValues(width, 0);
1808 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1809
1810 // Iterate of linear index, convert to coords space and insert broadcasted
1811 // 1-D vector in each position.
1812 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1813 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1814 });
1815 rewriter.replaceOp(broadcast, desc);
1816 return success();
1817 }
1818};
1819
1820/// Conversion pattern for a `vector.interleave`.
1821/// This supports fixed-sized vectors and scalable vectors.
1822struct VectorInterleaveOpLowering
1823 : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1825
1826 LogicalResult
1827 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1828 ConversionPatternRewriter &rewriter) const override {
1829 VectorType resultType = interleaveOp.getResultVectorType();
1830 // n-D interleaves should have been lowered already.
1831 if (resultType.getRank() != 1)
1832 return rewriter.notifyMatchFailure(interleaveOp,
1833 "InterleaveOp not rank 1");
1834 // If the result is rank 1, then this directly maps to LLVM.
1835 if (resultType.isScalable()) {
1836 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1837 interleaveOp, typeConverter->convertType(resultType),
1838 adaptor.getLhs(), adaptor.getRhs());
1839 return success();
1840 }
1841 // Lower fixed-size interleaves to a shufflevector. While the
1842 // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1843 // langref still recommends fixed-vectors use shufflevector, see:
1844 // https://llvm.org/docs/LangRef.html#id876.
1845 int64_t resultVectorSize = resultType.getNumElements();
1846 SmallVector<int32_t> interleaveShuffleMask;
1847 interleaveShuffleMask.reserve(resultVectorSize);
1848 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1849 interleaveShuffleMask.push_back(i);
1850 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1851 }
1852 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1853 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1854 interleaveShuffleMask);
1855 return success();
1856 }
1857};
1858
1859/// Conversion pattern for a `vector.deinterleave`.
1860/// This supports fixed-sized vectors and scalable vectors.
1861struct VectorDeinterleaveOpLowering
1862 : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1864
1865 LogicalResult
1866 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1867 ConversionPatternRewriter &rewriter) const override {
1868 VectorType resultType = deinterleaveOp.getResultVectorType();
1869 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1870 auto loc = deinterleaveOp.getLoc();
1871
1872 // Note: n-D deinterleave operations should be lowered to the 1-D before
1873 // converting to LLVM.
1874 if (resultType.getRank() != 1)
1875 return rewriter.notifyMatchFailure(deinterleaveOp,
1876 "DeinterleaveOp not rank 1");
1877
1878 if (resultType.isScalable()) {
1879 const auto *llvmTypeConverter = this->getTypeConverter();
1880 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1881 auto packedOpResults =
1882 llvmTypeConverter->packOperationResults(deinterleaveResults);
1883 auto intrinsic = LLVM::vector_deinterleave2::create(
1884 rewriter, loc, packedOpResults, adaptor.getSource());
1885
1886 auto evenResult = LLVM::ExtractValueOp::create(
1887 rewriter, loc, intrinsic->getResult(0), 0);
1888 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1889 intrinsic->getResult(0), 1);
1890
1891 rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1892 return success();
1893 }
1894 // Lower fixed-size deinterleave to two shufflevectors. While the
1895 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1896 // langref still recommends fixed-vectors use shufflevector, see:
1897 // https://llvm.org/docs/LangRef.html#id889.
1898 int64_t resultVectorSize = resultType.getNumElements();
1899 SmallVector<int32_t> evenShuffleMask;
1900 SmallVector<int32_t> oddShuffleMask;
1901
1902 evenShuffleMask.reserve(resultVectorSize);
1903 oddShuffleMask.reserve(resultVectorSize);
1904
1905 for (int i = 0; i < sourceType.getNumElements(); ++i) {
1906 if (i % 2 == 0)
1907 evenShuffleMask.push_back(i);
1908 else
1909 oddShuffleMask.push_back(i);
1910 }
1911
1912 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1913 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1914 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1915 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1916 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1917
1918 rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1919 return success();
1920 }
1921};
1922
1923/// Conversion pattern for a `vector.from_elements`.
1924struct VectorFromElementsLowering
1925 : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1927
1928 LogicalResult
1929 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1930 ConversionPatternRewriter &rewriter) const override {
1931 Location loc = fromElementsOp.getLoc();
1932 VectorType vectorType = fromElementsOp.getType();
1933 // Only support 1-D vectors. Multi-dimensional vectors should have been
1934 // transformed to 1-D vectors by the vector-to-vector transformations before
1935 // this.
1936 if (vectorType.getRank() > 1)
1937 return rewriter.notifyMatchFailure(fromElementsOp,
1938 "rank > 1 vectors are not supported");
1939 Type llvmType = typeConverter->convertType(vectorType);
1940 Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1941 Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1942 for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1943 auto constIdx =
1944 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1945 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
1946 val, constIdx);
1947 }
1948 rewriter.replaceOp(fromElementsOp, result);
1949 return success();
1950 }
1951};
1952
1953/// Conversion pattern for a `vector.to_elements`.
1954struct VectorToElementsLowering
1955 : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
1957
1958 LogicalResult
1959 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1960 ConversionPatternRewriter &rewriter) const override {
1961 Location loc = toElementsOp.getLoc();
1962 auto idxType = typeConverter->convertType(rewriter.getIndexType());
1963 Value source = adaptor.getSource();
1964
1965 SmallVector<Value> results(toElementsOp->getNumResults());
1966 for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1967 // Create an extractelement operation only for results that are not dead.
1968 if (element.use_empty())
1969 continue;
1970
1971 auto constIdx = LLVM::ConstantOp::create(
1972 rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1973 auto llvmType = typeConverter->convertType(element.getType());
1974
1975 Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1976 source, constIdx);
1977 results[idx] = result;
1978 }
1979
1980 rewriter.replaceOp(toElementsOp, results);
1981 return success();
1982 }
1983};
1984
1985/// Conversion pattern for vector.step.
1986struct VectorScalableStepOpLowering
1987 : public ConvertOpToLLVMPattern<vector::StepOp> {
1989
1990 LogicalResult
1991 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1992 ConversionPatternRewriter &rewriter) const override {
1993 auto resultType = cast<VectorType>(stepOp.getType());
1994 if (!resultType.isScalable()) {
1995 return failure();
1996 }
1997 Type llvmType = typeConverter->convertType(stepOp.getType());
1998 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1999 return success();
2000 }
2001};
2002
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %d = vector.shape_cast %%flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  // High default benefit so this lowering is preferred over generic
  // contraction lowerings whenever it applies.
  ContractionOpToMatmulOpLowering(MLIRContext *context,
                                  PatternBenefit benefit = 100)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};
2026
/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul
/// semantics directly into `llvm.intr.matrix.multiply`:
/// BEFORE:
/// ```mlir
/// %res = vector.contract #matmat_trait %lhs, %rhs, %acc
/// : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %lhs = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
/// %rhs = vector.shape_cast %arg1 : vector<4x3xf32> to vector<12xf32>
/// %matmul = llvm.intr.matrix.multiply %lhs, %rhs
/// %res = arith.addf %acc, %matmul : vector<2x3xf32>
/// ```
//
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  // A matmul contraction has exactly (parallel, parallel, reduction)
  // iterator types.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed-precision contractions are not handled; source and destination
  // element types must match.
  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // The matrix intrinsic works on flattened (1-D) operands.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);

  Value mul = LLVM::MatrixMultiplyOp::create(
      rew, loc,
      VectorType::get(lhsRows * rhsColumns,
                      cast<VectorType>(lhs.getType()).getElementType()),
      lhs, rhs, lhsRows, lhsColumns, rhsColumns);

  // Reshape the flat product back to the 2-D accumulator shape.
  mul = vector::ShapeCastOp::create(
      rew, loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Accumulate with the appropriate add for the element type.
  Value res = isa<IntegerType>(elementType)
                  ? static_cast<Value>(
                        arith::AddIOp::create(rew, loc, op.getAcc(), mul))
                  : static_cast<Value>(
                        arith::AddFOp::create(rew, loc, op.getAcc(), mul));

  return res;
}
2136
2137/// Lowers vector.transpose directly to llvm.intr.matrix.transpose
2138///
2139/// BEFORE:
2140/// ```mlir
2141/// %tr = vector.transpose %vec, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
2142/// ```
2143/// AFTER:
2144/// ```mlir
2145/// %vec_cs = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
2146/// %tr = llvm.intr.matrix.transpose %vec_sc
2147/// {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
2148/// %res = vector.shape_cast %tr : vector<8xf32> to vector<4x2xf32>
2149/// ```
2150class TransposeOpToMatrixTransposeOpLowering
2151 : public OpRewritePattern<vector::TransposeOp> {
2152public:
2153 using Base::Base;
2154
2155 LogicalResult matchAndRewrite(vector::TransposeOp op,
2156 PatternRewriter &rewriter) const override {
2157 auto loc = op.getLoc();
2158
2159 Value input = op.getVector();
2160 VectorType inputType = op.getSourceVectorType();
2161 VectorType resType = op.getResultVectorType();
2162
2163 if (inputType.isScalable())
2164 return rewriter.notifyMatchFailure(
2165 op, "This lowering does not support scalable vectors");
2166
2167 // Set up convenience transposition table.
2168 ArrayRef<int64_t> transp = op.getPermutation();
2169
2170 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2171 return failure();
2172 }
2173
2174 Type flattenedType =
2175 VectorType::get(resType.getNumElements(), resType.getElementType());
2176 auto matrix =
2177 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2178 auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
2179 auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
2180 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2181 matrix, rows, columns);
2182 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
2183 return success();
2184 }
2185};
2186
2187} // namespace
2188
2190 RewritePatternSet &patterns) {
2191 patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
2192}
2193
2195 RewritePatternSet &patterns, PatternBenefit benefit) {
2196 patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
2197}
2198
2200 RewritePatternSet &patterns, PatternBenefit benefit) {
2201 patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
2202 benefit);
2203}
2204
2205/// Populate the given list with patterns that convert from Vector to LLVM.
2207 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2208 bool reassociateFPReductions, bool force32BitVectorIndices,
2209 bool useVectorAlignment) {
2210 // This function populates only ConversionPatterns, not RewritePatterns.
2211 MLIRContext *ctx = converter.getDialect()->getContext();
2212 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2213 patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2214 patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2215 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2216 VectorLoadStoreConversion<vector::StoreOp>,
2217 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2218 VectorGatherOpConversion, VectorScatterOpConversion>(
2219 converter, useVectorAlignment);
2220 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2221 VectorExtractOpConversion, VectorFMAOp1DConversion,
2222 VectorInsertOpConversion, VectorPrintOpConversion,
2223 VectorTypeCastOpConversion, VectorScaleOpConversion,
2224 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2225 VectorBroadcastScalarToLowRankLowering,
2226 VectorBroadcastScalarToNdLowering,
2227 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2228 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2229 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2230 VectorToElementsLowering, VectorScalableStepOpLowering>(
2231 converter);
2232}
2233
namespace {
/// ConvertToLLVM dialect-interface wiring for the Vector dialect: declares
/// the LLVM dialect dependency and registers the Vector-to-LLVM conversion
/// patterns when a generic convert-to-llvm pass runs.
struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  VectorToLLVMDialectInterface(Dialect *dialect)
      : ConvertToLLVMPatternInterface(dialect) {}

  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  // The conversion produces LLVM dialect ops, so that dialect must be loaded.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    // Uses the default values for the boolean lowering options.
    populateVectorToLLVMConversionPatterns(typeConverter, patterns);
  }
};
} // namespace
2253
2255 DialectRegistry &registry) {
2256 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
2257 dialect->addInterfaces<VectorToLLVMDialectInterface>();
2258 });
2259}
return success()
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
lhs
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition Unit.cpp:18
#define mul(a, b)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
MLIRContext * getContext() const
Definition Builders.h:56
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
Conversion from types to the LLVM IR dialect.
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition Pattern.h:330
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but where each element for which ShapedType::isDynamic is true is replaced by the corresponding entry of dynamicValues.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:122
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.