MLIR 23.0.0git
Pattern.cpp
Go to the documentation of this file.
1//===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/AffineMap.h"
15#include "llvm/Support/CheckedArithmetic.h"
16#include "llvm/Support/MathExtras.h"
17
18using namespace mlir;
19
20//===----------------------------------------------------------------------===//
21// ConvertToLLVMPattern
22//===----------------------------------------------------------------------===//
23
25 StringRef rootOpName, MLIRContext *context,
26 const LLVMTypeConverter &typeConverter, PatternBenefit benefit)
27 : ConversionPattern(typeConverter, rootOpName, benefit, context) {}
28
30 return static_cast<const LLVMTypeConverter *>(
31 ConversionPattern::getTypeConverter());
32}
33
34LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
35 return *getTypeConverter()->getDialect();
36}
37
41
42Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
43 return IntegerType::get(&getTypeConverter()->getContext(),
44 getTypeConverter()->getPointerBitwidth(addressSpace));
45}
46
48 return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
49}
50
51Type ConvertToLLVMPattern::getPtrType(unsigned addressSpace) const {
52 return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext(),
53 addressSpace);
54}
55
57
59 Location loc,
60 Type resultType,
61 int64_t value) {
62 return LLVM::ConstantOp::create(builder, loc, resultType,
63 builder.getIndexAttr(value));
64}
65
67 ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
68 Value memRefDesc, ValueRange indices,
69 LLVM::GEPNoWrapFlags noWrapFlags) const {
70 return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
71 memRefDesc, indices, noWrapFlags);
72}
73
74// Check if the MemRefType `type` is supported by the lowering. We currently
75// only support memrefs with identity maps.
77 MemRefType type) const {
78 if (!type.getLayout().isIdentity())
79 return false;
80 return static_cast<bool>(typeConverter->convertType(type));
81}
82
84 auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type);
85 if (failed(addressSpace))
86 return {};
87 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
88}
89
91 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
92 ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
93 SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
94 assert(isConvertibleAndHasIdentityMaps(memRefType) &&
95 "layout maps must have been normalized away");
96 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
97 static_cast<ssize_t>(dynamicSizes.size()) &&
98 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
99
100 sizes.reserve(memRefType.getRank());
101 unsigned dynamicIndex = 0;
102 Type indexType = getIndexType();
103 for (int64_t size : memRefType.getShape()) {
104 sizes.push_back(
105 size == ShapedType::kDynamic
106 ? dynamicSizes[dynamicIndex++]
107 : createIndexAttrConstant(rewriter, loc, indexType, size));
108 }
109
110 // Strides: iterate sizes in reverse order and multiply.
111 int64_t stride = 1;
112 bool overflowed = false;
113 Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
114 strides.resize(memRefType.getRank());
115 for (auto i = memRefType.getRank(); i-- > 0;) {
116 strides[i] = overflowed ? LLVM::PoisonOp::create(rewriter, loc, indexType)
117 : runningStride;
118
119 int64_t staticSize = memRefType.getShape()[i];
120 bool useSizeAsStride = stride == 1;
121 if (staticSize == ShapedType::kDynamic)
122 stride = ShapedType::kDynamic;
123 if (stride != ShapedType::kDynamic) {
124 std::optional<int64_t> res = llvm::checkedMul(stride, staticSize);
125
126 if (!res)
127 overflowed = true;
128 else
129 stride = res.value();
130 }
131
132 if (overflowed)
133 runningStride = LLVM::PoisonOp::create(rewriter, loc, indexType);
134 else if (useSizeAsStride)
135 runningStride = sizes[i];
136 else if (stride == ShapedType::kDynamic)
137 runningStride =
138 LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
139 else
140 runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
141 }
142 if (sizeInBytes) {
143 // Buffer size in bytes.
144 Type elementType = typeConverter->convertType(memRefType.getElementType());
145 auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
146 Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
147 Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
148 elementType, nullPtr, runningStride);
149 size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
150 } else {
151 size = runningStride;
152 }
153}
154
156 Location loc, Type type, ConversionPatternRewriter &rewriter) const {
157 // Compute the size of an individual element. This emits the MLIR equivalent
158 // of the following sizeof(...) implementation in LLVM IR:
159 // %0 = getelementptr %elementType* null, %indexType 1
160 // %1 = ptrtoint %elementType* %0 to %indexType
161 // which is a common pattern of getting the size of a type in bytes.
162 Type llvmType = typeConverter->convertType(type);
163 auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
164 auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
165 auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
166 nullPtr, ArrayRef<LLVM::GEPArg>{1});
167 return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep);
168}
169
171 Location loc, MemRefType memRefType, ValueRange dynamicSizes,
172 ConversionPatternRewriter &rewriter) const {
173 assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
174 static_cast<ssize_t>(dynamicSizes.size()) &&
175 "dynamicSizes size doesn't match dynamic sizes count in memref shape");
176
177 Type indexType = getIndexType();
178 Value numElements = memRefType.getRank() == 0
179 ? createIndexAttrConstant(rewriter, loc, indexType, 1)
180 : nullptr;
181 unsigned dynamicIndex = 0;
182
183 // Compute the total number of memref elements.
184 for (int64_t staticSize : memRefType.getShape()) {
185 if (numElements) {
186 Value size =
187 staticSize == ShapedType::kDynamic
188 ? dynamicSizes[dynamicIndex++]
189 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
190 numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
191 } else {
192 numElements =
193 staticSize == ShapedType::kDynamic
194 ? dynamicSizes[dynamicIndex++]
195 : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
196 }
197 }
198 return numElements;
199}
200
201/// Creates and populates the memref descriptor struct given all its fields.
203 Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
204 ArrayRef<Value> sizes, ArrayRef<Value> strides,
205 ConversionPatternRewriter &rewriter) const {
206 auto structType = typeConverter->convertType(memRefType);
207 auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);
208
209 // Field 1: Allocated pointer, used for malloc/free.
210 memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
211
212 // Field 2: Actual aligned pointer to payload.
213 memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
214
215 // Field 3: Offset in aligned pointer.
216 Type indexType = getIndexType();
217 memRefDescriptor.setOffset(
218 rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
219
220 // Fields 4: Sizes.
221 for (const auto &en : llvm::enumerate(sizes))
222 memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
223
224 // Field 5: Strides.
225 for (const auto &en : llvm::enumerate(strides))
226 memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
227
228 return memRefDescriptor;
229}
230
232 OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
233 Value operand, bool toDynamic) const {
234 // Convert memory space.
235 FailureOr<unsigned> addressSpace =
237 if (failed(addressSpace))
238 return {};
239
240 // Get frequently used types.
241 Type indexType = getTypeConverter()->getIndexType();
242
243 // Find the malloc and free, or declare them if necessary.
244 auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
245 FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
246 if (toDynamic) {
247 mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
248 if (failed(mallocFunc))
249 return {};
250 }
251 if (!toDynamic) {
252 freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
253 if (failed(freeFunc))
254 return {};
255 }
256
257 UnrankedMemRefDescriptor desc(operand);
259 builder, loc, *getTypeConverter(), desc, *addressSpace);
260
261 // Allocate memory, copy, and free the source if necessary.
262 Value memory = toDynamic
263 ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
264 allocationSize)
265 .getResult()
266 : LLVM::AllocaOp::create(builder, loc, getPtrType(),
267 IntegerType::get(getContext(), 8),
268 allocationSize,
269 /*alignment=*/0);
270 Value source = desc.memRefDescPtr(builder, loc);
271 LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
272 if (!toDynamic)
273 LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
274
275 // Create a new descriptor. The same descriptor can be returned multiple
276 // times, attempting to modify its pointer can lead to memory leaks
277 // (allocated twice and overwritten) or double frees (the caller does not
278 // know if the descriptor points to the same memory).
279 Type descriptorType = getTypeConverter()->convertType(memRefType);
280 if (!descriptorType)
281 return {};
282 auto updatedDesc =
283 UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
284 Value rank = desc.rank(builder, loc);
285 updatedDesc.setRank(builder, loc, rank);
286 updatedDesc.setMemRefDescPtr(builder, loc, memory);
287 return updatedDesc;
288}
289
291 OpBuilder &builder, Location loc, TypeRange origTypes,
292 SmallVectorImpl<Value> &operands, bool toDynamic) const {
293 assert(origTypes.size() == operands.size() &&
294 "expected as may original types as operands");
295 for (unsigned i = 0, e = operands.size(); i < e; ++i) {
296 if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
297 Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
298 operands[i], toDynamic);
299 if (!updatedDesc)
300 return failure();
301 operands[i] = updatedDesc;
302 }
303 }
304 return success();
305}
306
307//===----------------------------------------------------------------------===//
308// Detail methods
309//===----------------------------------------------------------------------===//
310
311/// Replaces the given operation "op" with a new operation of type "targetOp"
312/// and given operands.
314 Operation *op, StringRef targetOp, ValueRange operands,
315 ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
316 const LLVMTypeConverter &typeConverter,
317 ConversionPatternRewriter &rewriter) {
318 unsigned numResults = op->getNumResults();
319
320 SmallVector<Type> resultTypes;
321 if (numResults != 0) {
322 resultTypes.push_back(
323 typeConverter.packOperationResults(op->getResultTypes()));
324 if (!resultTypes.back())
325 return failure();
326 }
327
328 // Create the operation through state since we don't know its C++ type.
329 OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
330 resultTypes, targetAttrs);
331 state.propertiesAttr = propertiesAttr;
332 Operation *newOp = rewriter.create(state);
333
334 // If the operation produced 0 or 1 result, return them immediately.
335 if (numResults == 0)
336 return rewriter.eraseOp(op), success();
337 if (numResults == 1)
338 return rewriter.replaceOp(op, newOp->getResult(0)), success();
339
340 // Otherwise, it had been converted to an operation producing a structure.
341 // Extract individual results from the structure and return them as list.
342 SmallVector<Value, 4> results;
343 results.reserve(numResults);
344 for (unsigned i = 0; i < numResults; ++i) {
345 results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(),
346 newOp->getResult(0), i));
347 }
348 rewriter.replaceOp(op, results);
349 return success();
350}
351
353 Operation *op, StringRef intrinsic, ValueRange operands,
354 const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) {
355 auto loc = op->getLoc();
356
357 if (!llvm::all_of(operands, [](Value value) {
358 return LLVM::isCompatibleType(value.getType());
359 }))
360 return failure();
361
362 unsigned numResults = op->getNumResults();
363 Type resType;
364 if (numResults != 0)
365 resType = typeConverter.packOperationResults(op->getResultTypes());
366
367 auto callIntrOp = LLVM::CallIntrinsicOp::create(
368 rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands);
369 // Propagate attributes.
370 callIntrOp->setAttrs(op->getAttrDictionary());
371
372 if (numResults <= 1) {
373 // Directly replace the original op.
374 rewriter.replaceOp(op, callIntrOp);
375 return success();
376 }
377
378 // Extract individual results from packed structure and use them as
379 // replacements.
380 SmallVector<Value, 4> results;
381 results.reserve(numResults);
382 Value intrRes = callIntrOp.getResults();
383 for (unsigned i = 0; i < numResults; ++i)
384 results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
385 rewriter.replaceOp(op, results);
386
387 return success();
388}
389
390static unsigned getBitWidth(Type type) {
391 if (type.isIntOrFloat())
392 return type.getIntOrFloatBitWidth();
393
394 auto vec = cast<VectorType>(type);
395 assert(!vec.isScalable() && "scalable vectors are not supported");
396 return vec.getNumElements() * getBitWidth(vec.getElementType());
397}
398
399/// Returns true if every leaf in `type` (recursing through LLVM arrays and
400/// structs) is either equal to `dstType` or has a fixed bit width.
401static bool isFixedSizeAggregate(Type type, Type dstType) {
402 if (type == dstType)
403 return true;
404 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
405 return isFixedSizeAggregate(arrayType.getElementType(), dstType);
406 if (auto structType = dyn_cast<LLVM::LLVMStructType>(type))
407 return llvm::all_of(structType.getBody(), [&](Type fieldType) {
408 return isFixedSizeAggregate(fieldType, dstType);
409 });
410 if (auto vecTy = dyn_cast<VectorType>(type))
411 return !vecTy.isScalable();
412 return type.isIntOrFloat();
413}
414
416 int32_t value) {
417 Type i32 = builder.getI32Type();
418 return LLVM::ConstantOp::create(builder, loc, i32, value);
419}
420
421/// Recursive implementation of decomposeValue. When
422/// `permitVariablySizedScalars` is false, callers must ensure
423/// isFixedSizeAggregate() holds before calling this.
424static void decomposeValueImpl(OpBuilder &builder, Location loc, Value src,
426 Type srcType = src.getType();
427 if (srcType == dstType) {
428 result.push_back(src);
429 return;
430 }
431
432 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(srcType)) {
433 for (auto i : llvm::seq(arrayType.getNumElements())) {
434 Value elem = LLVM::ExtractValueOp::create(builder, loc, src, i);
435 decomposeValueImpl(builder, loc, elem, dstType, result);
436 }
437 return;
438 }
439
440 if (auto structType = dyn_cast<LLVM::LLVMStructType>(srcType)) {
441 for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
442 Value field = LLVM::ExtractValueOp::create(builder, loc, src,
443 static_cast<int64_t>(i));
444 decomposeValueImpl(builder, loc, field, dstType, result);
445 }
446 return;
447 }
448
449 // Variably sized leaf types (e.g., ptr) — pass through as-is.
450 if (!srcType.isIntOrFloat() && !isa<VectorType>(srcType)) {
451 result.push_back(src);
452 return;
453 }
454
455 unsigned srcBitWidth = getBitWidth(srcType);
456 unsigned dstBitWidth = getBitWidth(dstType);
457 if (srcBitWidth == dstBitWidth) {
458 Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
459 result.push_back(cast);
460 return;
461 }
462
463 if (dstBitWidth > srcBitWidth) {
464 auto smallerInt = builder.getIntegerType(srcBitWidth);
465 if (srcType != smallerInt)
466 src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
467
468 auto largerInt = builder.getIntegerType(dstBitWidth);
469 Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
470 result.push_back(res);
471 return;
472 }
473 int64_t numElements = llvm::divideCeil(srcBitWidth, dstBitWidth);
474 int64_t roundedBitWidth = numElements * dstBitWidth;
475
476 // Pad out values that don't decompose evenly before creating a vector.
477 if (roundedBitWidth != srcBitWidth) {
478 auto srcInt = builder.getIntegerType(srcBitWidth);
479 if (srcType != srcInt)
480 src = LLVM::BitcastOp::create(builder, loc, srcInt, src);
481 auto roundedInt = builder.getIntegerType(roundedBitWidth);
482 src = LLVM::ZExtOp::create(builder, loc, roundedInt, src);
483 }
484
485 auto vecType = VectorType::get(numElements, dstType);
486 src = LLVM::BitcastOp::create(builder, loc, vecType, src);
487
488 for (auto i : llvm::seq(numElements)) {
489 Value idx = createI32Constant(builder, loc, i);
490 Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
491 result.push_back(elem);
492 }
493}
494
496 Value src, Type dstType,
498 bool permitVariablySizedScalars) {
499 // Check the type tree before emitting any IR, so that a failing pattern
500 // leaves the IR unmodified.
501 if (!permitVariablySizedScalars &&
502 !isFixedSizeAggregate(src.getType(), dstType))
503 return failure();
504
505 decomposeValueImpl(builder, loc, src, dstType, result);
506 return success();
507}
508
509/// Recursive implementation of composeValue. Consumes elements from `src`
510/// starting at `offset`, advancing it past the consumed elements.
512 size_t &offset, Type dstType) {
513 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(dstType)) {
514 Value result = LLVM::PoisonOp::create(builder, loc, arrayType);
515 Type elemType = arrayType.getElementType();
516 for (auto i : llvm::seq(arrayType.getNumElements())) {
517 Value elem = composeValueImpl(builder, loc, src, offset, elemType);
518 result = LLVM::InsertValueOp::create(builder, loc, result, elem, i);
519 }
520 return result;
521 }
522
523 if (auto structType = dyn_cast<LLVM::LLVMStructType>(dstType)) {
524 Value result = LLVM::PoisonOp::create(builder, loc, structType);
525 for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
526 Value field = composeValueImpl(builder, loc, src, offset, fieldType);
527 result = LLVM::InsertValueOp::create(builder, loc, result, field,
528 static_cast<int64_t>(i));
529 }
530 return result;
531 }
532
533 // Variably sized leaf types (e.g., ptr) — consume and return as-is.
534 if (!dstType.isIntOrFloat() && !isa<VectorType>(dstType))
535 return src[offset++];
536
537 unsigned dstBitWidth = getBitWidth(dstType);
538
539 Value front = src[offset];
540 if (front.getType() == dstType) {
541 ++offset;
542 return front;
543 }
544
545 // Single element wider than or equal to dst: bitcast/trunc.
546 if (front.getType().isIntOrFloat() || isa<VectorType>(front.getType())) {
547 unsigned srcBitWidth = getBitWidth(front.getType());
548 if (srcBitWidth >= dstBitWidth) {
549 ++offset;
550 Value res = front;
551 if (dstBitWidth < srcBitWidth) {
552 auto largerInt = builder.getIntegerType(srcBitWidth);
553 if (res.getType() != largerInt)
554 res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
555
556 auto smallerInt = builder.getIntegerType(dstBitWidth);
557 res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
558 }
559 if (res.getType() != dstType)
560 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
561 return res;
562 }
563 }
564
565 // Multiple elements narrower than dst: gather into a vector and bitcast.
566 unsigned elemBitWidth = getBitWidth(front.getType());
567 int64_t numElements = llvm::divideCeil(dstBitWidth, elemBitWidth);
568 int64_t roundedBitWidth = numElements * elemBitWidth;
569
570 auto vecType = VectorType::get(numElements, front.getType());
571 Value res = LLVM::PoisonOp::create(builder, loc, vecType);
572 for (auto i : llvm::seq(numElements)) {
573 Value idx = createI32Constant(builder, loc, i);
574 res = LLVM::InsertElementOp::create(builder, loc, vecType, res,
575 src[offset++], idx);
576 }
577
578 // Undo any padding decomposition might have introduced.
579 if (roundedBitWidth != dstBitWidth) {
580 auto roundedInt = builder.getIntegerType(roundedBitWidth);
581 res = LLVM::BitcastOp::create(builder, loc, roundedInt, res);
582 auto dstInt = builder.getIntegerType(dstBitWidth);
583 res = LLVM::TruncOp::create(builder, loc, dstInt, res);
584 if (dstType != dstInt)
585 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
586 } else {
587 if (res.getType() != dstType)
588 res = LLVM::BitcastOp::create(builder, loc, dstType, res);
589 }
590
591 return res;
592}
593
595 Type dstType) {
596 assert(!src.empty() && "src range must not be empty");
597 size_t offset = 0;
598 Value result = composeValueImpl(builder, loc, src, offset, dstType);
599 assert(offset == src.size() && "not all decomposed values were consumed");
600 return result;
601}
602
604 const LLVMTypeConverter &converter,
605 MemRefType type, Value memRefDesc,
607 LLVM::GEPNoWrapFlags noWrapFlags) {
608 auto [strides, offset] = type.getStridesAndOffset();
609
610 MemRefDescriptor memRefDescriptor(memRefDesc);
611 // Use a canonical representation of the start address so that later
612 // optimizations have a longer sequence of instructions to CSE.
613 // If we don't do that we would sprinkle the memref.offset in various
614 // position of the different address computations.
615 Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);
616
617 LLVM::IntegerOverflowFlags intOverflowFlags =
618 LLVM::IntegerOverflowFlags::none;
619 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
620 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
621 }
622 if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
623 intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
624 }
625
626 Type indexType = converter.getIndexType();
627 Value index;
628 for (int i = 0, e = indices.size(); i < e; ++i) {
629 Value increment = indices[i];
630 if (strides[i] != 1) { // Skip if stride is 1.
631 Value stride =
632 ShapedType::isDynamic(strides[i])
633 ? memRefDescriptor.stride(builder, loc, i)
634 : LLVM::ConstantOp::create(builder, loc, indexType,
635 builder.getIndexAttr(strides[i]));
636 increment = LLVM::MulOp::create(builder, loc, increment, stride,
637 intOverflowFlags);
638 }
639 index = index ? LLVM::AddOp::create(builder, loc, index, increment,
640 intOverflowFlags)
641 : increment;
642 }
643
644 Type elementPtrType = memRefDescriptor.getElementPtrType();
645 return index
646 ? LLVM::GEPOp::create(builder, loc, elementPtrType,
647 converter.convertType(type.getElementType()),
648 base, index, noWrapFlags)
649 : base;
650}
651
652/// Return the given type if it's a floating point type. If the given type is
653/// a vector type, return its element type if it's a floating point type.
654static FloatType getFloatingPointType(Type type) {
655 if (auto floatType = dyn_cast<FloatType>(type))
656 return floatType;
657 if (auto vecType = dyn_cast<VectorType>(type))
658 return dyn_cast<FloatType>(vecType.getElementType());
659 return nullptr;
660}
661
663 const TypeConverter &typeConverter, Type type) {
664 FloatType floatType = getFloatingPointType(type);
665 if (!floatType)
666 return false;
667 Type convertedType = typeConverter.convertType(floatType);
668 if (!convertedType)
669 return true;
670 return !isa<FloatType>(convertedType);
671}
672
674 Operation *op, const TypeConverter &typeConverter) {
675 for (Value operand : op->getOperands())
676 if (isUnsupportedFloatingPointType(typeConverter, operand.getType()))
677 return true;
678 return llvm::any_of(op->getResults(), [&typeConverter](OpResult r) {
679 return isUnsupportedFloatingPointType(typeConverter, r.getType());
680 });
681}
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
return success()
static unsigned getBitWidth(Type type)
Definition Pattern.cpp:390
static FloatType getFloatingPointType(Type type)
Return the given type if it's a floating point type.
Definition Pattern.cpp:654
static bool isFixedSizeAggregate(Type type, Type dstType)
Returns true if every leaf in type (recursing through LLVM arrays and structs) is either equal to dst...
Definition Pattern.cpp:401
static Value composeValueImpl(OpBuilder &builder, Location loc, ValueRange src, size_t &offset, Type dstType)
Recursive implementation of composeValue.
Definition Pattern.cpp:511
static void decomposeValueImpl(OpBuilder &builder, Location loc, Value src, Type dstType, SmallVectorImpl< Value > &result)
Recursive implementation of decomposeValue.
Definition Pattern.cpp:424
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
Type getVoidType() const
Gets the MLIR type wrapping the LLVM void type.
Definition Pattern.cpp:47
MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef< Value > sizes, ArrayRef< Value > strides, ConversionPatternRewriter &rewriter) const
Creates and populates a canonical memref descriptor struct.
Definition Pattern.cpp:202
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.cpp:24
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
Definition Pattern.cpp:66
void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &strides, Value &size, bool sizeInBytes=true) const
Computes sizes, strides and buffer size of memRefType with identity layout.
Definition Pattern.cpp:90
Type getPtrType(unsigned addressSpace=0) const
Get the MLIR type wrapping the LLVM ptr type.
Definition Pattern.cpp:51
Type getIndexType() const
Gets the MLIR type wrapping the LLVM integer type whose bit width is defined by the used type convert...
Definition Pattern.cpp:38
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:29
Value getNumElements(Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter) const
Computes total number of elements for the given MemRef and dynamicSizes.
Definition Pattern.cpp:170
LLVM::LLVMDialect & getDialect() const
Returns the LLVM dialect.
Definition Pattern.cpp:34
Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const
Computes the size of type in bytes.
Definition Pattern.cpp:155
Type getIntPtrType(unsigned addressSpace=0) const
Gets the MLIR type wrapping the LLVM integer type whose bit width corresponds to that of a LLVM point...
Definition Pattern.cpp:42
Value copyUnrankedDescriptor(OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, Value operand, bool toDynamic) const
Copies the given unranked memory descriptor to heap-allocated memory (if toDynamic is true) or to sta...
Definition Pattern.cpp:231
LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl< Value > &operands, bool toDynamic) const
Copies the memory descriptor for any operands that were unranked descriptors originally to heap-alloc...
Definition Pattern.cpp:290
Type getElementPtrType(MemRefType type) const
Returns the type of a pointer to an element of the memref.
Definition Pattern.cpp:83
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value)
Create a constant Op producing a value of resultType from an index-typed integer attribute.
Definition Pattern.cpp:58
bool isConvertibleAndHasIdentityMaps(MemRefType type) const
Returns if the given memref type is convertible to LLVM and has an identity layout map.
Definition Pattern.cpp:76
Type getVoidPtrType() const
Get the MLIR type wrapping the LLVM i8* type.
Definition Pattern.cpp:56
Conversion from types to the LLVM IR dialect.
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type getIndexType() const
Gets the LLVM representation of the index type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
This class helps build Operations.
Definition Builders.h:209
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
static Value computeSize(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, UnrankedMemRefDescriptor desc, unsigned addressSpace)
Builds and returns IR computing the size in bytes (suitable for opaque allocation).
Value memRefDescPtr(OpBuilder &builder, Location loc) const
Builds IR extracting ranked memref descriptor ptr.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value rank(OpBuilder &builder, Location loc) const
Builds IR extracting the rank from the descriptor.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, Type type)
Return "true" if the given type is an unsupported floating point type.
Definition Pattern.cpp:662
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:313
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
Definition Pattern.cpp:673
LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic, ValueRange operands, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
Replaces the given operation "op" with a call to an LLVM intrinsic with the specified name "intrinsic...
Definition Pattern.cpp:352
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition Pattern.cpp:603
LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType, SmallVectorImpl< Value > &result, bool permitVariablySizedScalars=false)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
Definition Pattern.cpp:495
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
Definition Pattern.cpp:594
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Include the generated interface declarations.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.