MLIR 23.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
31#include "llvm/ADT/APFloat.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/Support/Casting.h"
34
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::vector;
39
/// Helper that picks the proper sequence for inserting `val2` into `val1` at
/// position `pos`.
/// - rank == 1: the target is an LLVM vector, so emit `llvm.insertelement`;
///   the position must be an SSA value, so an index-typed constant is
///   materialized for it.
/// - rank > 1: the target lowers to an LLVM aggregate, so emit
///   `llvm.insertvalue`, which takes a static position.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       const LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    // Convert the index type so the constant matches the lowered index width.
    auto constant = LLVM::ConstantOp::create(
        rewriter, loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
                                         constant);
  }
  return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
}
56
/// Helper that picks the proper sequence for extracting position `pos` from
/// `val`.
/// - rank <= 1 (0-D and 1-D cases): the source is an LLVM vector, so emit
///   `llvm.extractelement` with the position materialized as an index-typed
///   constant.
/// - rank > 1: the source lowers to an LLVM aggregate, so emit
///   `llvm.extractvalue` with a static position.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = LLVM::ConstantOp::create(
        rewriter, loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
                                          constant);
  }
  return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
}
71
/// Helper that writes the data-layout preferred alignment of `vectorType`
/// into the out-parameter `align`. Fails (leaving `align` untouched) when the
/// vector type cannot be converted to an LLVM-compatible type.
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
                                 VectorType vectorType, unsigned &align) {
  Type convertedVectorTy = typeConverter.convertType(vectorType);
  if (!convertedVectorTy)
    return failure();

  // A throwaway LLVMContext suffices: the translator is only used to query
  // the data layout, no IR is created in it.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(convertedVectorTy,
                                     typeConverter.getDataLayout());

  return success();
}
86
/// Helper that writes the data-layout preferred alignment of the *element
/// type* of `memrefType` into the out-parameter `align`. Fails (leaving
/// `align` untouched) when the element type cannot be converted.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}
101
102// Helper to resolve the alignment for vector load/store, gather and scatter
103// ops. If useVectorAlignment is true, get the preferred alignment for the
104// vector type in the operation. This option is used for hardware backends with
105// vectorization. Otherwise, use the preferred alignment of the element type of
106// the memref. Note that if you choose to use vector alignment, the shape of the
107// vector type must be resolved before the ConvertVectorToLLVM pass is run.
108LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
109 VectorType vectorType,
110 MemRefType memrefType, unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
113 if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
114 return failure();
115 }
116 } else {
117 if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
118 return failure();
119 }
120 }
121 return success();
122}
123
124// Check if the last stride is non-unit and has a valid memory space.
125static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
126 const LLVMTypeConverter &converter) {
127 if (!memRefType.isLastDimUnitStride())
128 return failure();
129 if (failed(converter.getMemRefAddressSpace(memRefType)))
130 return failure();
131 return success();
132}
133
/// Builds a vector of pointers by adding the index vector `index` to the base
/// pointer `base`: emits a single `llvm.getelementptr` whose result type is a
/// (possibly scalable) vector of element pointers, ready for use by masked
/// gather/scatter intrinsics.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            const LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, VectorType vectorType) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
  // The GEP result is a vector of pointers with the same length (and
  // scalability) as the index vector.
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType =
      LLVM::getVectorType(pType, vectorType.getDimSize(0),
                          /*isScalable=*/vectorType.getScalableDims()[0]);
  return LLVM::GEPOp::create(
      rewriter, loc, ptrsType,
      typeConverter.convertType(memRefType.getElementType()), base, index);
}
150
151/// Convert `foldResult` into a Value. Integer attribute is converted to
152/// an LLVM constant op.
154 OpFoldResult foldResult) {
155 if (auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
158 }
159
160 return cast<Value>(foldResult);
161}
162
163namespace {
164
165/// Trivial Vector to LLVM conversions
166using VectorScaleOpConversion =
168
/// Conversion pattern for a vector.bitcast.
/// Lowers 0-D and 1-D vector.bitcast ops directly to `llvm.bitcast`; higher
/// ranks are rejected here and left to other patterns (e.g. unrolling).
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};
188
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // vector.load -> llvm.load; the nontemporal hint is carried over.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
                                            /*volatile_=*/false,
                                            loadOp.getNontemporal());
}
200
/// Overload replacing vector.maskedload with `llvm.intr.masked.load`; the
/// mask and pass-through vector are taken from the adaptor.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}
208
/// Overload replacing vector.store with `llvm.store`; the nontemporal hint is
/// carried over.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align, /*volatile_=*/false,
                                             storeOp.getNontemporal());
}
217
/// Overload replacing vector.maskedstore with `llvm.intr.masked.store`.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
225
226/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
227/// vector.maskedstore.
228template <class LoadOrStoreOp>
229class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
230public:
231 explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
232 bool useVectorAlign)
233 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
234 useVectorAlignment(useVectorAlign) {}
235 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
236
237 LogicalResult
238 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
239 typename LoadOrStoreOp::Adaptor adaptor,
240 ConversionPatternRewriter &rewriter) const override {
241 // Only 1-D vectors can be lowered to LLVM.
242 VectorType vectorTy = loadOrStoreOp.getVectorType();
243 if (vectorTy.getRank() > 1)
244 return failure();
245
246 auto loc = loadOrStoreOp->getLoc();
247 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
248
249 // Resolve alignment.
250 // Explicit alignment takes priority over use-vector-alignment.
251 unsigned align = loadOrStoreOp.getAlignment().value_or(0);
252 if (!align &&
253 failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
254 memRefTy, align, useVectorAlignment)))
255 return rewriter.notifyMatchFailure(loadOrStoreOp,
256 "could not resolve alignment");
257
258 // Resolve address.
259 auto vtype = cast<VectorType>(
260 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
261 Value dataPtr = this->getStridedElementPtr(
262 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
263 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
264 rewriter);
265 return success();
266 }
267
268private:
269 // If true, use the preferred alignment of the vector type.
270 // If false, use the preferred alignment of the element type
271 // of the memref. This flag is intended for use with hardware
272 // backends that require alignment of vector operations.
273 const bool useVectorAlignment;
274};
275
276/// Conversion pattern for a vector.gather.
277class VectorGatherOpConversion
278 : public ConvertOpToLLVMPattern<vector::GatherOp> {
279public:
280 explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
281 bool useVectorAlign)
282 : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
283 useVectorAlignment(useVectorAlign) {}
284 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
285
286 LogicalResult
287 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter) const override {
289 Location loc = gather->getLoc();
290 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
291 assert(memRefType && "The base should be bufferized");
292
293 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
294 return rewriter.notifyMatchFailure(gather, "memref type not supported");
295
296 VectorType vType = gather.getVectorType();
297 if (vType.getRank() > 1) {
298 return rewriter.notifyMatchFailure(
299 gather, "only 1-D vectors can be lowered to LLVM");
300 }
301
302 // Resolve alignment.
303 // Explicit alignment takes priority over use-vector-alignment.
304 unsigned align = gather.getAlignment().value_or(0);
305 if (!align &&
306 failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
307 memRefType, align, useVectorAlignment)))
308 return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
309
310 // Resolve address.
311 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
312 adaptor.getBase(), adaptor.getOffsets());
313 Value base = adaptor.getBase();
314 Value ptrs =
315 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
316 base, ptr, adaptor.getIndices(), vType);
317
318 // Replace with the gather intrinsic.
319 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
320 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
321 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
322 return success();
323 }
324
325private:
326 // If true, use the preferred alignment of the vector type.
327 // If false, use the preferred alignment of the element type
328 // of the memref. This flag is intended for use with hardware
329 // backends that require alignment of vector operations.
330 const bool useVectorAlignment;
331};
332
333/// Conversion pattern for a vector.scatter.
334class VectorScatterOpConversion
335 : public ConvertOpToLLVMPattern<vector::ScatterOp> {
336public:
337 explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
338 bool useVectorAlign)
339 : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
340 useVectorAlignment(useVectorAlign) {}
341
342 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
343
344 LogicalResult
345 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter) const override {
347 auto loc = scatter->getLoc();
348 auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
349 assert(memRefType && "The base should be bufferized");
350
351 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
352 return rewriter.notifyMatchFailure(scatter, "memref type not supported");
353
354 VectorType vType = scatter.getVectorType();
355 if (vType.getRank() > 1) {
356 return rewriter.notifyMatchFailure(
357 scatter, "only 1-D vectors can be lowered to LLVM");
358 }
359
360 // Resolve alignment.
361 // Explicit alignment takes priority over use-vector-alignment.
362 unsigned align = scatter.getAlignment().value_or(0);
363 if (!align &&
364 failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
365 memRefType, align, useVectorAlignment)))
366 return rewriter.notifyMatchFailure(scatter,
367 "could not resolve alignment");
368
369 // Resolve address.
370 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
371 adaptor.getBase(), adaptor.getOffsets());
372 Value ptrs =
373 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
374 adaptor.getBase(), ptr, adaptor.getIndices(), vType);
375
376 // Replace with the scatter intrinsic.
377 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
378 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
379 rewriter.getI32IntegerAttr(align));
380 return success();
381 }
382
383private:
384 // If true, use the preferred alignment of the vector type.
385 // If false, use the preferred alignment of the element type
386 // of the memref. This flag is intended for use with hardware
387 // backends that require alignment of vector operations.
388 const bool useVectorAlignment;
389};
390
/// Conversion pattern for a vector.expandload.
/// Lowers directly to the `llvm.intr.masked.expandload` intrinsic.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // From:
    // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
    // The pointer alignment defaults to 1.
    uint64_t alignment = expand.getAlignment().value_or(1);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
        alignment);
    return success();
  }
};
419
/// Conversion pattern for a vector.compressstore.
/// Lowers directly to the `llvm.intr.masked.compressstore` intrinsic.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // From:
    // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
    // The pointer alignment defaults to 1.
    uint64_t alignment = compress.getAlignment().value_or(1);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
    return success();
  }
};
446
/// Reduction neutral classes for overloading. These are empty tag types used
/// purely to select, by overload resolution, which neutral element
/// `createReductionNeutralValue` materializes for a given reduction kind.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value (for add/or/xor-style reductions).
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
                                  rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value (for integer mul).
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
                                  rewriter.getIntegerAttr(llvmType, 1));
}

/// Create the reduction neutral fp one value (for fp mul).
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
                                  rewriter.getFloatAttr(llvmType, 1.0));
}

/// Create the reduction neutral all-ones value (for integer and).
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
      rewriter.getIntegerAttr(
          llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int minimum value (for smax).
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int minimum value (for umax).
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int maximum value (for smin).
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int maximum value (for umin).
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral fp minimum value: a positive quiet NaN.
/// NOTE(review): this relies on the consuming reduction treating NaN as the
/// "losing" operand (minnum-style semantics) — confirm for each intrinsic
/// this start value feeds.
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = cast<FloatType>(llvmType);
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/false)));
}

/// Create the reduction neutral fp maximum value: a negative quiet NaN.
/// NOTE(review): same NaN-semantics caveat as the fp minimum neutral above.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = cast<FloatType>(llvmType);
  return LLVM::ConstantOp::create(
      rewriter, loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}
556
557/// Returns `accumulator` if it has a valid value. Otherwise, creates and
558/// returns a new accumulator value using `ReductionNeutral`.
559template <class ReductionNeutral>
560static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
561 Location loc, Type llvmType,
562 Value accumulator) {
563 if (accumulator)
564 return accumulator;
565
566 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
567 llvmType);
568}
569
/// Creates an i32 value holding the runtime length of the 1-D vector type
/// `llvmType`. This is used as the effective vector length by intrinsics that
/// support dynamic vector lengths (VP intrinsics).
/// For fixed vectors this is simply the static dimension; for scalable
/// vectors it is `vscale * baseLength`.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  Value baseVecLength = LLVM::ConstantOp::create(
      rewriter, loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));

  if (!vType.getScalableDims()[0])
    return baseVecLength;

  // For a scalable vector type, create and return `vScale * baseVecLength`.
  Value vScale = vector::VectorScaleOp::create(rewriter, loc);
  // vector.vscale yields an index; narrow it to i32 to match baseVecLength.
  vScale =
      arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
  Value scalableVecLength =
      arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
  return scalableVecLength;
}
594
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector
/// reduction intrinsic to use and `ScalarOp` is the scalar operation used to
/// combine the accumulator value when one is provided (non-null).
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result =
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);

  // Fold the optional accumulator in with one extra scalar op.
  if (accumulator)
    result = ScalarOp::create(rewriter, loc, accumulator, result);
  return result;
}
611
/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector reduction intrinsic to use and `predicate` is the predicate used to
/// compare-and-select against the accumulator value when one is provided
/// (non-null).
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result =
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
  if (accumulator) {
    // min/max with the accumulator is expressed as icmp + select.
    Value cmp =
        LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
    result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
  }
  return result;
}
629
namespace {
/// Maps an LLVM fp min/max vector-reduction intrinsic to the scalar LLVM op
/// with matching NaN semantics, used to fold in an accumulator value.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
650
/// Lowers an fp min/max `vector.reduction` to the reduction intrinsic
/// `LLVMRedIntrinOp`, folding an optional accumulator in with the scalar op
/// selected by VectorToScalarMapper. Fastmath flags are propagated to the
/// intrinsic.
template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  Value result =
      LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);

  if (accumulator) {
    result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
        rewriter, loc, result, accumulator);
  }

  return result;
}
665
/// Mask neutral classes for overloading: empty tag types selecting which
/// neutral value `getMaskNeutralValue` produces for masked-off lanes.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};

/// Get the mask neutral floating point maximum value: the most negative
/// finite-magnitude value, so any real lane wins the maximum.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMaximum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
}
/// Get the mask neutral floating point minimum value: the largest positive
/// finite value, so any real lane wins the minimum.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
}
682
/// Create the mask neutral floating point MLIR vector constant: a splat of
/// the `MaskNeutral` scalar over `vectorType`, used to overwrite masked-off
/// lanes before a non-masked reduction.
template <typename MaskNeutral>
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Type vectorType) {
  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
  return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
}
693
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics
/// for `fmaximum`/`fminimum`: masked-off lanes are first replaced by the mask
/// neutral value with a select, then a regular reduction is emitted.
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
                                Location loc, Type llvmType,
                                Value vectorOperand, Value accumulator,
                                Value mask, LLVM::FastmathFlagsAttr fmf) {
  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
      rewriter, loc, llvmType, vectorOperand.getType());
  // select(mask, operand, neutral): inactive lanes cannot affect the result.
  const Value selectedVectorByMask = LLVM::SelectOp::create(
      rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
      rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}
711
/// Lowers a reduction to an LLVM intrinsic that takes a start value (e.g.
/// fadd/fmul reductions). A missing accumulator is replaced by the
/// `ReductionNeutral` element so the intrinsic always receives a start value.
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
                                 /*start_value=*/accumulator, vectorOperand,
                                 fmf);
}
723
/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
/// This overload handles the unmasked VP form.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
                                   /*startValue=*/accumulator, vectorOperand);
}
737
/// Masked overload: lowers to a VP reduction intrinsic, supplying the mask
/// and an explicit effective vector length derived from the operand type.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
                                   /*start_value=*/accumulator, vectorOperand,
                                   mask, vectorLength);
}
750
/// Dispatching overload: routes to the integer or fp VP reduction intrinsic
/// (with its matching neutral element) based on the scalar result type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}
766
767/// Conversion pattern for all vector reductions.
768class VectorReductionOpConversion
769 : public ConvertOpToLLVMPattern<vector::ReductionOp> {
770public:
771 explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
772 bool reassociateFPRed)
773 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
774 reassociateFPReductions(reassociateFPRed) {}
775
776 LogicalResult
777 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
778 ConversionPatternRewriter &rewriter) const override {
779 auto kind = reductionOp.getKind();
780 Type eltType = reductionOp.getDest().getType();
781 Type llvmType = typeConverter->convertType(eltType);
782 Value operand = adaptor.getVector();
783 Value acc = adaptor.getAcc();
784 Location loc = reductionOp.getLoc();
785
786 if (eltType.isIntOrIndex()) {
787 // Integer reductions: add/mul/min/max/and/or/xor.
788 Value result;
789 switch (kind) {
790 case vector::CombiningKind::ADD:
791 result =
792 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
793 LLVM::AddOp>(
794 rewriter, loc, llvmType, operand, acc);
795 break;
796 case vector::CombiningKind::MUL:
797 result =
798 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
799 LLVM::MulOp>(
800 rewriter, loc, llvmType, operand, acc);
801 break;
802 case vector::CombiningKind::MINUI:
803 result = createIntegerReductionComparisonOpLowering<
804 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
805 LLVM::ICmpPredicate::ule);
806 break;
807 case vector::CombiningKind::MINSI:
808 result = createIntegerReductionComparisonOpLowering<
809 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
810 LLVM::ICmpPredicate::sle);
811 break;
812 case vector::CombiningKind::MAXUI:
813 result = createIntegerReductionComparisonOpLowering<
814 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
815 LLVM::ICmpPredicate::uge);
816 break;
817 case vector::CombiningKind::MAXSI:
818 result = createIntegerReductionComparisonOpLowering<
819 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
820 LLVM::ICmpPredicate::sge);
821 break;
822 case vector::CombiningKind::AND:
823 result =
824 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
825 LLVM::AndOp>(
826 rewriter, loc, llvmType, operand, acc);
827 break;
828 case vector::CombiningKind::OR:
829 result =
830 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
831 LLVM::OrOp>(
832 rewriter, loc, llvmType, operand, acc);
833 break;
834 case vector::CombiningKind::XOR:
835 result =
836 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
837 LLVM::XOrOp>(
838 rewriter, loc, llvmType, operand, acc);
839 break;
840 default:
841 return failure();
842 }
843 rewriter.replaceOp(reductionOp, result);
844
845 return success();
846 }
847
848 if (!isa<FloatType>(eltType))
849 return failure();
850
851 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
852 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
853 reductionOp.getContext(),
854 convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
855 fmf = LLVM::FastmathFlagsAttr::get(
856 reductionOp.getContext(),
857 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
858 : LLVM::FastmathFlags::none));
859
860 // Floating-point reductions: add/mul/min/max
861 Value result;
862 if (kind == vector::CombiningKind::ADD) {
863 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
864 ReductionNeutralZero>(
865 rewriter, loc, llvmType, operand, acc, fmf);
866 } else if (kind == vector::CombiningKind::MUL) {
867 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
868 ReductionNeutralFPOne>(
869 rewriter, loc, llvmType, operand, acc, fmf);
870 } else if (kind == vector::CombiningKind::MINIMUMF) {
871 result =
872 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
873 rewriter, loc, llvmType, operand, acc, fmf);
874 } else if (kind == vector::CombiningKind::MAXIMUMF) {
875 result =
876 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
877 rewriter, loc, llvmType, operand, acc, fmf);
878 } else if (kind == vector::CombiningKind::MINNUMF) {
879 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
880 rewriter, loc, llvmType, operand, acc, fmf);
881 } else if (kind == vector::CombiningKind::MAXNUMF) {
882 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
883 rewriter, loc, llvmType, operand, acc, fmf);
884 } else {
885 return failure();
886 }
887
888 rewriter.replaceOp(reductionOp, result);
889 return success();
890 }
891
892private:
893 const bool reassociateFPReductions;
894};
895
896/// Base class to convert a `vector.mask` operation while matching traits
897/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
898/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
899/// method performs a second match against the maskable operation `MaskedOp`.
900/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
901/// implemented by the concrete conversion classes. This method can match
902/// against specific traits of the `vector.mask` and the maskable operation. It
903/// must replace the `vector.mask` operation.
904template <class MaskedOp>
905class VectorMaskOpConversionBase
906 : public ConvertOpToLLVMPattern<vector::MaskOp> {
907public:
908 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
909
910 LogicalResult
911 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
912 ConversionPatternRewriter &rewriter) const final {
913 // Match against the maskable operation kind.
914 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
915 if (!maskedOp)
916 return failure();
917 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
918 }
919
920protected:
921 virtual LogicalResult
922 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
923 vector::MaskableOpInterface maskableOp,
924 ConversionPatternRewriter &rewriter) const = 0;
925};
926
/// Conversion for a masked `vector.reduction` (i.e. a `vector.reduction`
/// nested inside a `vector.mask`). Most combining kinds map onto LLVM's
/// vector-predicated (VP) reduction intrinsics; `maximumf`/`minimumf` have no
/// VP counterpart and are lowered via a regular reduction after selecting a
/// mask-neutral value for the masked-off lanes.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // NOTE: `fmf` is only consumed by the non-VP lowering paths
    // (MAXIMUMF/MINIMUMF) below; the VP intrinsics take no fast-math flags.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    // Dispatch on the combining kind. Each helper folds a neutral start value
    // (for a missing accumulator) into the predicated reduction. The
    // four-argument ADD/MUL instantiations pick the integer or FP intrinsic
    // based on the element type.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXNUMF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    // No VP intrinsic exists for maximumf/minimumf: select the mask-neutral
    // value into masked-off lanes, then run the regular reduction intrinsic.
    case CombiningKind::MAXIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};
1026
/// Conversion pattern for `vector.shuffle`. 0-D/1-D shuffles whose operands
/// have exactly the same vector type map directly onto `llvm.shufflevector`;
/// all other cases are expanded into per-element extract/insert pairs over the
/// nested-aggregate (array-of-vectors) representation.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    ArrayRef<int64_t> mask = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    // Two well-formed shapes: shuffling two 0-D vectors yields a 1-D result;
    // otherwise both operands must have the same rank as the result.
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
          rewriter, loc, adaptor.getV1(), adaptor.getV2(),
          llvm::to_vector_of<int32_t>(mask));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
      eltType = arrayType.getElementType();
    else
      eltType = cast<VectorType>(llvmType).getElementType();
    Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
    int64_t insPos = 0;
    for (int64_t extPos : mask) {
      // Mask indices >= v1Dim address the second operand (LLVM-style
      // concatenated indexing); rebase them into v2's index space.
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
1090
/// Conversion pattern for `vector.extract`, lowered to `llvm.extractvalue`
/// (pulling a member out of the nested-aggregate representation of an N-D
/// vector) and/or `llvm.extractelement` (pulling a scalar out of a 1-D
/// vector). Fails when an aggregate extraction would need dynamic indices,
/// which `llvm.extractvalue` cannot express.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Merge static and dynamic position operands into one index list.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
    // 1-d vectors. This nesting is modeled using arrays. We do this conversion
    // from a N-d vector extract to a nested aggregate vector extract in two
    // steps:
    //  - Extract a member from the nested aggregate. The result can be
    //    a lower rank nested aggregate or a vector (1-D). This is done using
    //    `llvm.extractvalue`.
    //  - Extract a scalar out of the vector if needed. This is done using
    //    `llvm.extractelement`.

    // Determine if we need to extract a member out of the aggregate. We
    // always need to extract a member if the input rank >= 2.
    bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
    // Determine if we need to extract a scalar as the result. We extract
    // a scalar if the extract is full rank, i.e., the number of indices is
    // equal to source vector rank.
    bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
                          extractOp.getSourceVectorType().getRank();

    // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
    // need to add a position for this change.
    if (extractOp.getSourceVectorType().getRank() == 0) {
      Type idxType = typeConverter->convertType(rewriter.getIndexType());
      positionVec.push_back(rewriter.getZeroAttr(idxType));
    }

    Value extracted = adaptor.getSource();
    if (extractsAggregate) {
      ArrayRef<OpFoldResult> position(positionVec);
      if (extractsScalar) {
        // If we are extracting a scalar from the extracted member, we drop
        // the last index, which will be used to extract the scalar out of the
        // vector.
        position = position.drop_back();
      }
      // llvm.extractvalue does not support dynamic dimensions.
      if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
        return failure();
      }
      extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
                                               getAsIntegers(position));
    }

    if (extractsScalar) {
      // The last position indexes into the 1-D vector; it may be dynamic, as
      // llvm.extractelement accepts an SSA index.
      extracted = LLVM::ExtractElementOp::create(
          rewriter, loc, extracted,
          getAsLLVMValue(rewriter, loc, positionVec.back()));
    }

    rewriter.replaceOp(extractOp, extracted);
    return success();
  }
};
1162
1163/// Conversion pattern that turns a vector.fma on a 1-D vector
1164/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1165/// This does not match vectors of n >= 2 rank.
1166///
1167/// Example:
1168/// ```
1169/// vector.fma %a, %a, %a : vector<8xf32>
1170/// ```
1171/// is converted to:
1172/// ```
1173/// llvm.intr.fmuladd %va, %va, %va:
1174/// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1175/// -> !llvm."<8 x f32>">
1176/// ```
1177class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1178public:
1179 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1180
1181 LogicalResult
1182 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1183 ConversionPatternRewriter &rewriter) const override {
1184 VectorType vType = fmaOp.getVectorType();
1185 if (vType.getRank() > 1)
1186 return failure();
1187
1188 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1189 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1190 return success();
1191 }
1192};
1193
/// Conversion pattern for `vector.insert`, lowered to `llvm.insertvalue` /
/// `llvm.insertelement` over the nested-aggregate (array-of-1-D-vectors)
/// representation of N-D vectors. Fails when an aggregate position contains
/// dynamic indices, which llvm.insertvalue/extractvalue cannot express.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Merge static and dynamic position operands into one index list.
    SmallVector<OpFoldResult> positionVec = getMixedValues(
        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);

    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
    // its explanatory comment about how N-D vectors are converted as nested
    // aggregates (llvm.array's) of 1D vectors.
    //
    // The innermost dimension of the destination vector, when converted to a
    // nested aggregate form, will always be a 1D vector.
    //
    // * If the insertion is happening into the innermost dimension of the
    //   destination vector:
    //   - If the destination is a nested aggregate, extract a 1D vector out of
    //     the aggregate. This can be done using llvm.extractvalue. The
    //     destination is now guaranteed to be a 1D vector, to which we are
    //     inserting.
    //   - Do the insertion into the 1D destination vector, and make the result
    //     the new source nested aggregate. This can be done using
    //     llvm.insertelement.
    // * Insert the source nested aggregate into the destination nested
    //   aggregate.

    // Determine if we need to extract/insert a 1D vector out of the aggregate.
    bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
    // Determine if we need to insert a scalar into the 1D vector.
    bool insertIntoInnermostDim =
        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();

    // All leading indices address the aggregate; the trailing index (if any)
    // addresses the scalar slot within the innermost 1D vector.
    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
        positionVec.begin(),
        insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
    OpFoldResult positionOfScalarWithin1DVector;
    if (destVectorType.getRank() == 0) {
      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
      // need to create a 0 here as the position into the 1D vector.
      Type idxType = typeConverter->convertType(rewriter.getIndexType());
      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
    } else if (insertIntoInnermostDim) {
      positionOfScalarWithin1DVector = positionVec.back();
    }

    // We are going to mutate this 1D vector until it is either the final
    // result (in the non-aggregate case) or the value that needs to be
    // inserted into the aggregate result.
    Value sourceAggregate = adaptor.getValueToStore();
    if (insertIntoInnermostDim) {
      // Scalar-into-1D-vector case, so we know we will have to create a
      // InsertElementOp. The question is into what destination.
      if (isNestedAggregate) {
        // Aggregate case: the destination for the InsertElementOp needs to be
        // extracted from the aggregate.
        if (!llvm::all_of(positionOf1DVectorWithinAggregate,
                          llvm::IsaPred<Attribute>)) {
          // llvm.extractvalue does not support dynamic dimensions.
          return failure();
        }
        sourceAggregate = LLVM::ExtractValueOp::create(
            rewriter, loc, adaptor.getDest(),
            getAsIntegers(positionOf1DVectorWithinAggregate));
      } else {
        // No-aggregate case. The destination for the InsertElementOp is just
        // the insertOp's destination.
        sourceAggregate = adaptor.getDest();
      }
      // Insert the scalar into the 1D vector.
      sourceAggregate = LLVM::InsertElementOp::create(
          rewriter, loc, sourceAggregate.getType(), sourceAggregate,
          adaptor.getValueToStore(),
          getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
    }

    Value result = sourceAggregate;
    if (isNestedAggregate) {
      if (!llvm::all_of(positionOf1DVectorWithinAggregate,
                        llvm::IsaPred<Attribute>)) {
        // llvm.insertvalue does not support dynamic dimensions.
        return failure();
      }
      result = LLVM::InsertValueOp::create(
          rewriter, loc, adaptor.getDest(), sourceAggregate,
          getAsIntegers(positionOf1DVectorWithinAggregate));
    }

    rewriter.replaceOp(insertOp, result);
    return success();
  }
};
1296
1297/// Lower vector.scalable.insert ops to LLVM vector.insert
1298struct VectorScalableInsertOpLowering
1299 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1300 using ConvertOpToLLVMPattern<
1301 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1302
1303 LogicalResult
1304 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1305 ConversionPatternRewriter &rewriter) const override {
1306 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1307 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1308 return success();
1309 }
1310};
1311
1312/// Lower vector.scalable.extract ops to LLVM vector.extract
1313struct VectorScalableExtractOpLowering
1314 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1315 using ConvertOpToLLVMPattern<
1316 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1317
1318 LogicalResult
1319 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1320 ConversionPatternRewriter &rewriter) const override {
1321 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1322 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1323 adaptor.getSource(), adaptor.getPos());
1324 return success();
1325 }
1326};
1327
1328/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1329///
1330/// Example:
1331/// ```
1332/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1333/// ```
1334/// is rewritten into:
1335/// ```
1336/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
1337/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1338/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1339/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1340/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1341/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1342/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1343/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1344/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1345/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1346/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1347/// // %r3 holds the final value.
1348/// ```
1349class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1350public:
1351 using Base::Base;
1352
1353 void initialize() {
1354 // This pattern recursively unpacks one dimension at a time. The recursion
1355 // bounded as the rank is strictly decreasing.
1356 setHasBoundedRewriteRecursion();
1357 }
1358
1359 LogicalResult matchAndRewrite(FMAOp op,
1360 PatternRewriter &rewriter) const override {
1361 auto vType = op.getVectorType();
1362 if (vType.getRank() < 2)
1363 return failure();
1364
1365 auto loc = op.getLoc();
1366 auto elemType = vType.getElementType();
1367 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1368 rewriter.getZeroAttr(elemType));
1369 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1370 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1371 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1372 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1373 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1374 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1375 desc = InsertOp::create(rewriter, loc, fma, desc, i);
1376 }
1377 rewriter.replaceOp(op, desc);
1378 return success();
1379 }
1380};
1381
1382/// Returns the strides if the memory underlying `memRefType` has a contiguous
1383/// static layout.
1384static std::optional<SmallVector<int64_t, 4>>
1385computeContiguousStrides(MemRefType memRefType) {
1386 int64_t offset;
1388 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1389 return std::nullopt;
1390 if (!strides.empty() && strides.back() != 1)
1391 return std::nullopt;
1392 // If no layout or identity layout, this is contiguous by definition.
1393 if (memRefType.getLayout().isIdentity())
1394 return strides;
1395
1396 // Otherwise, we must determine contiguity form shapes. This can only ever
1397 // work in static cases because MemRefType is underspecified to represent
1398 // contiguous dynamic shapes in other ways than with just empty/identity
1399 // layout.
1400 auto sizes = memRefType.getShape();
1401 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1402 if (ShapedType::isDynamic(sizes[index + 1]) ||
1403 ShapedType::isDynamic(strides[index]) ||
1404 ShapedType::isDynamic(strides[index + 1]))
1405 return std::nullopt;
1406 if (strides[index] != strides[index + 1] * sizes[index + 1])
1407 return std::nullopt;
1408 }
1409 return strides;
1410}
1411
/// Conversion pattern for `vector.type_cast` between statically shaped,
/// contiguous memrefs: reuses the source descriptor's pointers and rebuilds
/// the target memref descriptor with a zero offset and the target type's
/// static sizes/strides.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // Both source and target must lower to LLVM struct memref descriptors.
    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride =
          LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
1487
1488/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1489/// Non-scalable versions of this operation are handled in Vector Transforms.
class VectorCreateMaskOpConversion
    : public OpConversionPattern<vector::CreateMaskOp> {
public:
  /// `enableIndexOpt` selects i32 (instead of i64) as the element type of the
  /// step/bound vectors used for the comparison below.
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt)
      : OpConversionPattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // This pattern only handles 1-D scalable vectors; other cases are
    // handled by Vector Transforms (see class comment above).
    auto dstType = op.getType();
    if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    // Materialize the mask as: step_vector < splat(bound), yielding a vector
    // of i1 with the first `bound` lanes set.
    Value indices = LLVM::StepVectorOp::create(
        rewriter, loc,
        LLVM::getVectorType(idxType, dstType.getShape()[0],
                            /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 adaptor.getOperands()[0]);
    Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
    Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
                                       indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  // When true, use 32-bit lane indices for the comparison vectors.
  const bool force32BitVectorIndices;
};
1523
1524class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1525 SymbolTableCollection *symbolTables = nullptr;
1526
1527public:
1528 explicit VectorPrintOpConversion(
1529 const LLVMTypeConverter &typeConverter,
1530 SymbolTableCollection *symbolTables = nullptr)
1531 : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
1532 symbolTables(symbolTables) {}
1533
1534 // Lowering implementation that relies on a small runtime support library,
1535 // which only needs to provide a few printing methods (single value for all
1536 // data types, opening/closing bracket, comma, newline). The lowering splits
1537 // the vector into elementary printing operations. The advantage of this
1538 // approach is that the library can remain unaware of all low-level
1539 // implementation details of vectors while still supporting output of any
1540 // shaped and dimensioned vector.
1541 //
1542 // Note: This lowering only handles scalars, n-D vectors are broken into
1543 // printing scalars in loops in VectorToSCF.
1544 //
1545 // TODO: rely solely on libc in future? something else?
1546 //
1547 LogicalResult
1548 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1549 ConversionPatternRewriter &rewriter) const override {
1550 auto parent = printOp->getParentOfType<ModuleOp>();
1551 if (!parent)
1552 return failure();
1553
1554 auto loc = printOp->getLoc();
1555
1556 if (auto value = adaptor.getSource()) {
1557 Type printType = printOp.getPrintType();
1558 if (isa<VectorType>(printType)) {
1559 // Vectors should be broken into elementary print ops in VectorToSCF.
1560 return failure();
1561 }
1562 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1563 return failure();
1564 }
1565
1566 auto punct = printOp.getPunctuation();
1567 if (auto stringLiteral = printOp.getStringLiteral()) {
1568 auto createResult =
1569 LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1570 *stringLiteral, *getTypeConverter(),
1571 /*addNewline=*/false);
1572 if (createResult.failed())
1573 return failure();
1574
1575 } else if (punct != PrintPunctuation::NoPunctuation) {
1576 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1577 switch (punct) {
1578 case PrintPunctuation::Close:
1579 return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
1580 symbolTables);
1581 case PrintPunctuation::Open:
1582 return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
1583 symbolTables);
1584 case PrintPunctuation::Comma:
1585 return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
1586 symbolTables);
1587 case PrintPunctuation::NewLine:
1588 return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
1589 symbolTables);
1590 default:
1591 llvm_unreachable("unexpected punctuation");
1592 }
1593 }();
1594 if (failed(op))
1595 return failure();
1596 emitCall(rewriter, printOp->getLoc(), op.value());
1597 }
1598
1599 rewriter.eraseOp(printOp);
1600 return success();
1601 }
1602
1603private:
1604 enum class PrintConversion {
1605 // clang-format off
1606 None,
1607 ZeroExt64,
1608 SignExt64,
1609 Bitcast16
1610 // clang-format on
1611 };
1612
  /// Emits a call to the runtime function that prints a single scalar of type
  /// `printType`. Returns failure when the element type cannot be converted
  /// or has no runtime printing support (e.g. integers wider than 64 bits).
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    // Bail out early if the element type is not convertible at all.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    FailureOr<Operation *> printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isIndex()) {
      // Index is printed as an unsigned 64-bit value.
      printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer =
              LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
        } else {
          // Wider than 64 bits: no runtime support.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer =
              LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
        } else {
          // Wider than 64 bits: no runtime support.
          return failure();
        }
      }
    } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
      // Print other floating-point types using the APFloat runtime library.
      // The runtime call takes the APFloat semantics enum plus the raw bits
      // widened to i64.
      // NOTE(review): for the zext below to verify, `value` must already be
      // integer-typed here — presumably the LLVM type converter legalizes
      // these exotic float types (e.g. 8-bit floats) to same-width integers;
      // confirm against the type converter.
      int32_t sem =
          llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
      Value semValue = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(),
          rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
      Value floatBits =
          LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
      // NOTE(review): unlike the other branches, `printer` is consumed below
      // without a failed() check; the lookup is assumed to succeed here.
      printer =
          LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
      emitCall(rewriter, loc, printer.value(),
               ValueRange({semValue, floatBits}));
      return success();
    } else {
      // Unsupported element type.
      return failure();
    }
    if (failed(printer))
      return failure();

    // Adapt the operand to the 64-bit (or 16-bit bit-pattern) ABI expected
    // by the runtime print functions.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = arith::ExtUIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = arith::ExtSIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      // Reinterpret f16/bf16 as i16; the runtime decodes the bit pattern.
      value = LLVM::BitcastOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer.value(), value);
    return success();
  }
1702
1703 // Helper to emit a call.
1704 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1705 Operation *ref, ValueRange params = ValueRange()) {
1706 LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref),
1707 params);
1708 }
1709};
1710
1711/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
1712/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1713/// pattern, the higher rank cases are handled by another pattern.
1714struct VectorBroadcastScalarToLowRankLowering
1715 : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1716 using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1717
1718 LogicalResult
1719 matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
1720 ConversionPatternRewriter &rewriter) const override {
1721 if (isa<VectorType>(broadcast.getSourceType()))
1722 return rewriter.notifyMatchFailure(
1723 broadcast, "broadcast from vector type not handled");
1724
1725 VectorType resultType = broadcast.getType();
1726 if (resultType.getRank() > 1)
1727 return rewriter.notifyMatchFailure(broadcast,
1728 "broadcast to 2+-d handled elsewhere");
1729
1730 // First insert it into a poison vector so we can shuffle it.
1731 auto vectorType = typeConverter->convertType(broadcast.getType());
1732 Value poison =
1733 LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType);
1734 auto zero = LLVM::ConstantOp::create(
1735 rewriter, broadcast.getLoc(),
1736 typeConverter->convertType(rewriter.getIntegerType(32)),
1737 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1738
1739 // For 0-d vector, we simply do `insertelement`.
1740 if (resultType.getRank() == 0) {
1741 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1742 broadcast, vectorType, poison, adaptor.getSource(), zero);
1743 return success();
1744 }
1745
1746 auto v =
1747 LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
1748 poison, adaptor.getSource(), zero);
1749
1750 // For 1-d vector, we additionally do a `shufflevector`.
1751 int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
1752 SmallVector<int32_t> zeroValues(width, 0);
1753
1754 // Shuffle the value across the desired number of elements.
1755 auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1756 broadcast.getLoc(), v, poison, zeroValues);
1757 rewriter.replaceOp(broadcast, shuffle);
1758 return success();
1759 }
1760};
1761
1762/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
1763/// operation. Only broadcasts to 2+-d vector result types are lowered by this
1764/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1765/// are not converted to LLVM, only broadcasts from scalars are.
1766struct VectorBroadcastScalarToNdLowering
1767 : public ConvertOpToLLVMPattern<BroadcastOp> {
1768 using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1769
1770 LogicalResult
1771 matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
1772 ConversionPatternRewriter &rewriter) const override {
1773 if (isa<VectorType>(broadcast.getSourceType()))
1774 return rewriter.notifyMatchFailure(
1775 broadcast, "broadcast from vector type not handled");
1776
1777 VectorType resultType = broadcast.getType();
1778 if (resultType.getRank() <= 1)
1779 return rewriter.notifyMatchFailure(
1780 broadcast, "broadcast to 1-d or 0-d handled elsewhere");
1781
1782 // First insert it into an undef vector so we can shuffle it.
1783 auto loc = broadcast.getLoc();
1784 auto vectorTypeInfo =
1785 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1786 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1787 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1788 if (!llvmNDVectorTy || !llvm1DVectorTy)
1789 return failure();
1790
1791 // Construct returned value.
1792 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1793
1794 // Construct a 1-D vector with the broadcasted value that we insert in all
1795 // the places within the returned descriptor.
1796 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1797 auto zero = LLVM::ConstantOp::create(
1798 rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1799 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1800 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1801 vdesc, adaptor.getSource(), zero);
1802
1803 // Shuffle the value across the desired number of elements.
1804 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1805 SmallVector<int32_t> zeroValues(width, 0);
1806 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1807
1808 // Iterate of linear index, convert to coords space and insert broadcasted
1809 // 1-D vector in each position.
1810 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1811 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1812 });
1813 rewriter.replaceOp(broadcast, desc);
1814 return success();
1815 }
1816};
1817
1818/// Conversion pattern for a `vector.interleave`.
1819/// This supports fixed-sized vectors and scalable vectors.
1820struct VectorInterleaveOpLowering
1821 : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1823
1824 LogicalResult
1825 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1826 ConversionPatternRewriter &rewriter) const override {
1827 VectorType resultType = interleaveOp.getResultVectorType();
1828 // n-D interleaves should have been lowered already.
1829 if (resultType.getRank() != 1)
1830 return rewriter.notifyMatchFailure(interleaveOp,
1831 "InterleaveOp not rank 1");
1832 // If the result is rank 1, then this directly maps to LLVM.
1833 if (resultType.isScalable()) {
1834 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1835 interleaveOp, typeConverter->convertType(resultType),
1836 adaptor.getLhs(), adaptor.getRhs());
1837 return success();
1838 }
1839 // Lower fixed-size interleaves to a shufflevector. While the
1840 // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1841 // langref still recommends fixed-vectors use shufflevector, see:
1842 // https://llvm.org/docs/LangRef.html#id876.
1843 int64_t resultVectorSize = resultType.getNumElements();
1844 SmallVector<int32_t> interleaveShuffleMask;
1845 interleaveShuffleMask.reserve(resultVectorSize);
1846 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1847 interleaveShuffleMask.push_back(i);
1848 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1849 }
1850 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1851 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1852 interleaveShuffleMask);
1853 return success();
1854 }
1855};
1856
1857/// Conversion pattern for a `vector.deinterleave`.
1858/// This supports fixed-sized vectors and scalable vectors.
1859struct VectorDeinterleaveOpLowering
1860 : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1862
1863 LogicalResult
1864 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1865 ConversionPatternRewriter &rewriter) const override {
1866 VectorType resultType = deinterleaveOp.getResultVectorType();
1867 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1868 auto loc = deinterleaveOp.getLoc();
1869
1870 // Note: n-D deinterleave operations should be lowered to the 1-D before
1871 // converting to LLVM.
1872 if (resultType.getRank() != 1)
1873 return rewriter.notifyMatchFailure(deinterleaveOp,
1874 "DeinterleaveOp not rank 1");
1875
1876 if (resultType.isScalable()) {
1877 const auto *llvmTypeConverter = this->getTypeConverter();
1878 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1879 auto packedOpResults =
1880 llvmTypeConverter->packOperationResults(deinterleaveResults);
1881 auto intrinsic = LLVM::vector_deinterleave2::create(
1882 rewriter, loc, packedOpResults, adaptor.getSource());
1883
1884 auto evenResult = LLVM::ExtractValueOp::create(
1885 rewriter, loc, intrinsic->getResult(0), 0);
1886 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1887 intrinsic->getResult(0), 1);
1888
1889 rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1890 return success();
1891 }
1892 // Lower fixed-size deinterleave to two shufflevectors. While the
1893 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1894 // langref still recommends fixed-vectors use shufflevector, see:
1895 // https://llvm.org/docs/LangRef.html#id889.
1896 int64_t resultVectorSize = resultType.getNumElements();
1897 SmallVector<int32_t> evenShuffleMask;
1898 SmallVector<int32_t> oddShuffleMask;
1899
1900 evenShuffleMask.reserve(resultVectorSize);
1901 oddShuffleMask.reserve(resultVectorSize);
1902
1903 for (int i = 0; i < sourceType.getNumElements(); ++i) {
1904 if (i % 2 == 0)
1905 evenShuffleMask.push_back(i);
1906 else
1907 oddShuffleMask.push_back(i);
1908 }
1909
1910 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1911 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1912 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1913 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1914 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1915
1916 rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1917 return success();
1918 }
1919};
1920
1921/// Conversion pattern for a `vector.from_elements`.
1922struct VectorFromElementsLowering
1923 : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1925
1926 LogicalResult
1927 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1928 ConversionPatternRewriter &rewriter) const override {
1929 Location loc = fromElementsOp.getLoc();
1930 VectorType vectorType = fromElementsOp.getType();
1931 // Only support 1-D vectors. Multi-dimensional vectors should have been
1932 // transformed to 1-D vectors by the vector-to-vector transformations before
1933 // this.
1934 if (vectorType.getRank() > 1)
1935 return rewriter.notifyMatchFailure(fromElementsOp,
1936 "rank > 1 vectors are not supported");
1937 Type llvmType = typeConverter->convertType(vectorType);
1938 Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1939 Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1940 for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1941 auto constIdx =
1942 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1943 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
1944 val, constIdx);
1945 }
1946 rewriter.replaceOp(fromElementsOp, result);
1947 return success();
1948 }
1949};
1950
1951/// Conversion pattern for a `vector.to_elements`.
1952struct VectorToElementsLowering
1953 : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
1955
1956 LogicalResult
1957 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1958 ConversionPatternRewriter &rewriter) const override {
1959 Location loc = toElementsOp.getLoc();
1960 auto idxType = typeConverter->convertType(rewriter.getIndexType());
1961 Value source = adaptor.getSource();
1962
1963 SmallVector<Value> results(toElementsOp->getNumResults());
1964 for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1965 // Create an extractelement operation only for results that are not dead.
1966 if (element.use_empty())
1967 continue;
1968
1969 auto constIdx = LLVM::ConstantOp::create(
1970 rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1971 auto llvmType = typeConverter->convertType(element.getType());
1972
1973 Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1974 source, constIdx);
1975 results[idx] = result;
1976 }
1977
1978 rewriter.replaceOp(toElementsOp, results);
1979 return success();
1980 }
1981};
1982
1983/// Conversion pattern for vector.step.
1984struct VectorScalableStepOpLowering
1985 : public ConvertOpToLLVMPattern<vector::StepOp> {
1987
1988 LogicalResult
1989 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1990 ConversionPatternRewriter &rewriter) const override {
1991 auto resultType = cast<VectorType>(stepOp.getType());
1992 if (!resultType.isScalable()) {
1993 return failure();
1994 }
1995 Type llvmType = typeConverter->convertType(stepOp.getType());
1996 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1997 return success();
1998 }
1999};
2000
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// The high default benefit (100) makes this pattern win over
  /// lower-benefit contraction lowerings when both are registered.
  ContractionOpToMatmulOpLowering(MLIRContext *context,
                                  PatternBenefit benefit = 100)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}

  /// Implementation is defined out-of-line below.
  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};
2024
/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul
/// semantics directly into `llvm.intr.matrix.multiply`:
/// BEFORE:
/// ```mlir
/// %res = vector.contract #matmat_trait %lhs, %rhs, %acc
/// : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %lhs = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
/// %rhs = vector.shape_cast %arg1 : vector<4x3xf32> to vector<12xf32>
/// %matmul = llvm.intr.matrix.multiply %lhs, %rhs
/// %res = arith.addf %acc, %matmul : vector<2x3xf32>
/// ```
///
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  // Only the canonical (parallel, parallel, reduction) iterator order
  // qualifies as a matmul.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  // llvm.intr.matrix.multiply requires int or float elements, and the
  // accumulator must have the same element type as the operands.
  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // The matrix intrinsic operates on flattened (1-D) operands.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);

  Value mul = LLVM::MatrixMultiplyOp::create(
      rew, loc,
      VectorType::get(lhsRows * rhsColumns,
                      cast<VectorType>(lhs.getType()).getElementType()),
      lhs, rhs, lhsRows, lhsColumns, rhsColumns);

  // Cast the flat product back to the 2-D accumulator shape.
  mul = vector::ShapeCastOp::create(
      rew, loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Accumulate with the matching arithmetic op for the element type.
  Value res = isa<IntegerType>(elementType)
                  ? static_cast<Value>(
                        arith::AddIOp::create(rew, loc, op.getAcc(), mul))
                  : static_cast<Value>(
                        arith::AddFOp::create(rew, loc, op.getAcc(), mul));

  return res;
}
2134
2135/// Lowers vector.transpose directly to llvm.intr.matrix.transpose
2136///
2137/// BEFORE:
2138/// ```mlir
2139/// %tr = vector.transpose %vec, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
2140/// ```
2141/// AFTER:
2142/// ```mlir
2143/// %vec_cs = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
2144/// %tr = llvm.intr.matrix.transpose %vec_sc
2145/// {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
2146/// %res = vector.shape_cast %tr : vector<8xf32> to vector<4x2xf32>
2147/// ```
2148class TransposeOpToMatrixTransposeOpLowering
2149 : public OpRewritePattern<vector::TransposeOp> {
2150public:
2151 using Base::Base;
2152
2153 LogicalResult matchAndRewrite(vector::TransposeOp op,
2154 PatternRewriter &rewriter) const override {
2155 auto loc = op.getLoc();
2156
2157 Value input = op.getVector();
2158 VectorType inputType = op.getSourceVectorType();
2159 VectorType resType = op.getResultVectorType();
2160
2161 if (inputType.isScalable())
2162 return rewriter.notifyMatchFailure(
2163 op, "This lowering does not support scalable vectors");
2164
2165 // Set up convenience transposition table.
2166 ArrayRef<int64_t> transp = op.getPermutation();
2167
2168 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2169 return failure();
2170 }
2171
2172 Type flattenedType =
2173 VectorType::get(resType.getNumElements(), resType.getElementType());
2174 auto matrix =
2175 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2176 auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
2177 auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
2178 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2179 matrix, rows, columns);
2180 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
2181 return success();
2182 }
2183};
2184
2185} // namespace
2186
// NOTE(review): the opening line of this definition (return type and function
// name) appears to have been lost in extraction; the body registers the
// rank-reducing n-D FMA rewrite pattern.
    RewritePatternSet &patterns) {
  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
}
2191
// NOTE(review): the opening line of this definition (return type and function
// name) appears to have been lost in extraction; the body registers the
// contract-to-matrix-multiply lowering with the given benefit.
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
}
2196
// NOTE(review): the opening line of this definition (return type and function
// name) appears to have been lost in extraction; the body registers the
// transpose-to-matrix-transpose lowering with the given benefit.
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
                                                       benefit);
}
2202
/// Populate the given list with patterns that convert from Vector to LLVM.
// NOTE(review): the opening signature line (return type and function name)
// appears to have been lost in extraction; the parameters below are the LLVM
// type converter plus lowering options.
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices,
    bool useVectorAlignment) {
  // This function populates only ConversionPatterns, not RewritePatterns.
  MLIRContext *ctx = converter.getDialect()->getContext();
  // Patterns below are grouped by the constructor arguments they take.
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
  // Memory-access patterns take the `useVectorAlignment` option.
  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
               VectorLoadStoreConversion<vector::MaskedLoadOp>,
               VectorLoadStoreConversion<vector::StoreOp>,
               VectorLoadStoreConversion<vector::MaskedStoreOp>,
               VectorGatherOpConversion, VectorScatterOpConversion>(
      converter, useVectorAlignment);
  // The remaining patterns only need the type converter.
  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
               VectorExtractOpConversion, VectorFMAOp1DConversion,
               VectorInsertOpConversion, VectorPrintOpConversion,
               VectorTypeCastOpConversion, VectorScaleOpConversion,
               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
               VectorBroadcastScalarToLowRankLowering,
               VectorBroadcastScalarToNdLowering,
               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
               MaskedReductionOpConversion, VectorInterleaveOpLowering,
               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
               VectorToElementsLowering, VectorScalableStepOpLowering>(
      converter);
}
2231
2232namespace {
2233struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2235 void loadDependentDialects(MLIRContext *context) const final {
2236 context->loadDialect<LLVM::LLVMDialect>();
2237 }
2238
2239 /// Hook for derived dialect interface to provide conversion patterns
2240 /// and mark dialect legal for the conversion target.
2241 void populateConvertToLLVMConversionPatterns(
2242 ConversionTarget &target, LLVMTypeConverter &typeConverter,
2243 RewritePatternSet &patterns) const final {
2244 populateVectorToLLVMConversionPatterns(typeConverter, patterns);
2245 }
2246};
2247} // namespace
2248
// NOTE(review): the opening signature line (return type and function name)
// appears to have been lost in extraction; the body registers the
// ConvertToLLVM dialect-interface extension on the vector dialect.
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
    dialect->addInterfaces<VectorToLLVMDialectInterface>();
  });
}
return success()
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
lhs
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition Unit.cpp:18
#define mul(a, b)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
MLIRContext * getContext() const
Definition Builders.h:56
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition Pattern.h:330
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc, using the layout map of type.
Definition Pattern.cpp:603
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but where all elements for which ShapedType::isDynamic is true are replaced by the corresponding entries of dynamicValues.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:120
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e. inside a vector.mask operation).