MLIR 23.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Visitors.h"
24#include <cassert>
25#include <limits>
26#include <optional>
27
28#define DEBUG_TYPE "memref-to-spirv-pattern"
29
30using namespace mlir;
31
32//===----------------------------------------------------------------------===//
33// Utility functions
34//===----------------------------------------------------------------------===//
35
36/// Returns the offset of the value in `targetBits` representation.
37///
38/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
39/// It's assumed to be non-negative.
40///
41/// When accessing an element in the array treating as having elements of
42/// `targetBits`, multiple values are loaded in the same time. The method
43/// returns the offset where the `srcIdx` locates in the value. For example, if
44/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
45/// located at (x % 4) * 8. Because there are four elements in one i32, and one
46/// element has 8 bits.
47static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
48 int targetBits, OpBuilder &builder) {
49 assert(targetBits % sourceBits == 0);
50 Type type = srcIdx.getType();
51 IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
54 auto srcBitsValue =
55 builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
58}
59
60/// Returns an adjusted spirv::AccessChainOp. Based on the
61/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
62/// supported. During conversion if a memref of an unsupported type is used,
63/// load/stores to this memref need to be modified to use a supported higher
64/// bitwidth `targetBits` and extracting the required bits. For an accessing a
65/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
66/// the bits needed. The extraction of the actual bits needed are handled
67/// separately. Note that this only works for a 1-D tensor.
68static Value
70 spirv::AccessChainOp op, int sourceBits,
71 int targetBits, OpBuilder &builder) {
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
75 Type type = lastDim.getType();
76 IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
79 // There are two elements if this is a 1-D tensor.
80 assert(indices.size() == 2);
81 indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
84 indices);
85}
86
87/// Casts the given `srcBool` into an integer of `dstType`.
88static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
89 OpBuilder &builder) {
90 assert(srcBool.getType().isInteger(1));
91 if (dstType.isInteger(1))
92 return srcBool;
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
96 zero);
97}
98
99/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
100/// to the type destination type, and masked.
101static Value shiftValue(Location loc, Value value, Value offset, Value mask,
102 OpBuilder &builder) {
103 IntegerType dstType = cast<IntegerType>(mask.getType());
104 int targetBits = static_cast<int>(dstType.getWidth());
105 int valueBits = value.getType().getIntOrFloatBitWidth();
106 assert(valueBits <= targetBits);
107
108 if (valueBits == 1) {
109 value = castBoolToIntN(loc, value, dstType, builder);
110 } else {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
113 builder, loc, builder.getIntegerType(targetBits), value);
114 }
115
116 value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
117 }
118 return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
119 value, offset);
120}
121
122/// Returns true if the allocations of memref `type` generated from `allocOp`
123/// can be lowered to SPIR-V.
124static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
128 return false;
129 } else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
132 return false;
133 } else {
134 return false;
135 }
136
137 // Currently only support static shape and int or float, complex of int or
138 // float, or vector of int or float element type.
139 if (!type.hasStaticShape())
140 return false;
141
142 Type elementType = type.getElementType();
143 if (auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
147 return elementType.isIntOrFloat();
148}
149
150/// Returns the scope to use for atomic operations use for emulating store
151/// operations of unsupported integer bitwidths, based on the memref
152/// type. Returns std::nullopt on failure.
153static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
160 default:
161 break;
162 }
163 return {};
164}
165
166/// Extracts the element type from a SPIR-V pointer type pointing to storage.
167///
168/// For Kernel capability, the pointer points directly to the element type
169/// (possibly wrapped in an array). For Vulkan, the pointer points to a struct
170/// containing an array or runtime array, and we need to unwrap to get the
171/// element type.
172static Type
174 const SPIRVTypeConverter &typeConverter) {
175 if (typeConverter.allows(spirv::Capability::Kernel)) {
176 if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
177 return arrayType.getElementType();
178 return pointeeType;
179 }
180 // For Vulkan we need to extract element from wrapping struct and array.
181 Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
182 if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
183 return arrayType.getElementType();
184 return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
185}
186
187/// Casts the given `srcInt` into a boolean value.
188static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
189 if (srcInt.getType().isInteger(1))
190 return srcInt;
191
192 auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
193 return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
194}
195
196//===----------------------------------------------------------------------===//
197// Operation conversion
198//===----------------------------------------------------------------------===//
199
200// Note that DRR cannot be used for the patterns in this file: we may need to
201// convert type along the way, which requires ConversionPattern. DRR generates
202// normal RewritePattern.
203
204namespace {
205
/// Converts memref.alloca to SPIR-V Function-storage variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
215
/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
228
/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
239
/// Removes a deallocation if it pairs with a supported allocation. Currently
/// only removes deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
250
/// Converts memref.load to spirv.Load + spirv.AccessChain on integers,
/// emulating element bitwidths not supported by the target where needed.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
260
/// Converts memref.load to spirv.Load + spirv.AccessChain for non-integer
/// element types.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
270
/// Converts memref.load to spirv.Image + spirv.ImageFetch.
class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
280
/// Converts memref.store to spirv.Store on integers, emulating element
/// bitwidths not supported by the target where needed.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
290
/// Converts memref.memory_space_cast to the appropriate SPIR-V cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
301
/// Converts memref.store to spirv.Store for non-integer element types.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
311
/// Conversion pattern for memref.reinterpret_cast ops; the implementation is
/// defined out of line.
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
321
322class CastPattern final : public OpConversionPattern<memref::CastOp> {
323public:
324 using Base::Base;
325
326 LogicalResult
327 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
328 ConversionPatternRewriter &rewriter) const override {
329 Value src = adaptor.getSource();
330 Type srcType = src.getType();
331
332 const TypeConverter *converter = getTypeConverter();
333 Type dstType = converter->convertType(op.getType());
334 if (srcType != dstType)
335 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
336 diag << "types doesn't match: " << srcType << " and " << dstType;
337 });
338
339 rewriter.replaceOp(op, src);
340 return success();
341 }
342};
343
/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
355} // namespace
356
357//===----------------------------------------------------------------------===//
358// AllocaOp
359//===----------------------------------------------------------------------===//
360
361LogicalResult
362AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const {
364 MemRefType allocType = allocaOp.getType();
365 if (!isAllocationSupported(allocaOp, allocType))
366 return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
367
368 // Get the SPIR-V type for the allocation.
369 Type spirvType = getTypeConverter()->convertType(allocType);
370 if (!spirvType)
371 return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
372
373 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
374 spirv::StorageClass::Function,
375 /*initializer=*/nullptr);
376 return success();
377}
378
379//===----------------------------------------------------------------------===//
380// AllocOp
381//===----------------------------------------------------------------------===//
382
383LogicalResult
384AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
385 ConversionPatternRewriter &rewriter) const {
386 MemRefType allocType = operation.getType();
387 if (!isAllocationSupported(operation, allocType))
388 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
389
390 // Get the SPIR-V type for the allocation.
391 Type spirvType = getTypeConverter()->convertType(allocType);
392 if (!spirvType)
393 return rewriter.notifyMatchFailure(operation, "type conversion failed");
394
395 // Insert spirv.GlobalVariable for this allocation.
396 Operation *parent =
397 SymbolTable::getNearestSymbolTable(operation->getParentOp());
398 if (!parent)
399 return failure();
400 Location loc = operation.getLoc();
401 spirv::GlobalVariableOp varOp;
402 {
403 OpBuilder::InsertionGuard guard(rewriter);
404 Block &entryBlock = *parent->getRegion(0).begin();
405 rewriter.setInsertionPointToStart(&entryBlock);
406 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
407 std::string varName =
408 std::string("__workgroup_mem__") +
409 std::to_string(std::distance(varOps.begin(), varOps.end()));
410 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
411 /*initializer=*/nullptr);
412 }
413
414 // Get pointer to global variable at the current scope.
415 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
416 return success();
417}
418
//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//
422
/// Lowers memref.atomic_rmw to SPIR-V atomic ops, emulating sub-storage-width
/// element types for the bitwise kinds (ori/andi) only.
LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  // Only integer atomics are supported here.
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  // Build the access chain addressing the element within the converted memref.
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

  // Determine the source and destination bitwidths. The source is the original
  // memref element type and the destination is the SPIR-V storage type (e.g.,
  // i32 for Vulkan).
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  auto dstType = dyn_cast<IntegerType>(
      getElementTypeForStoragePointer(pointeeType, typeConverter));
  if (!dstType)
    return rewriter.notifyMatchFailure(
        atomicOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  // When the source and destination bitwidths match, emit the atomic operation
  // directly.
  if (srcBits == dstBits) {
#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

    switch (atomicOp.getKind()) {
      ATOMIC_CASE(addi, AtomicIAddOp);
      ATOMIC_CASE(maxs, AtomicSMaxOp);
      ATOMIC_CASE(maxu, AtomicUMaxOp);
      ATOMIC_CASE(mins, AtomicSMinOp);
      ATOMIC_CASE(minu, AtomicUMinOp);
      ATOMIC_CASE(ori, AtomicOrOp);
      ATOMIC_CASE(andi, AtomicAndOp);
    default:
      return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
    }

#undef ATOMIC_CASE

    return success();
  }

  // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
  // storage type (e.g., i32). We need to adjust the index and shift/mask the
  // value to operate on the correct bits within the wider storage element.
  //
  // Only ori and andi can be emulated because they operate bitwise and don't
  // carry across byte boundaries. Other kinds (addi, max, min) would require
  // CAS loops.
  if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
      atomicOp.getKind() != arith::AtomicRMWKind::andi) {
    return rewriter.notifyMatchFailure(
        atomicOp,
        "atomic op on sub-element-width types is only supported for ori/andi");
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        atomicOp,
        "sub-element-width atomic ops unsupported with Kernel capability");

  auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Compute the bit offset within the storage element and adjust the pointer
  // to address the containing storage element.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value result;
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::ori: {
    // OR only sets bits, so shifting the value to the target position and
    // ORing with zeros in other positions preserves the unaffected bits.
    Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
        loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
    Value storeVal =
        shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
    result = spirv::AtomicOrOp::create(
        rewriter, loc, dstType, adjustedPtr, *scope,
        spirv::MemorySemantics::AcquireRelease, storeVal);
    break;
  }
  case arith::AtomicRMWKind::andi: {
    // Build a mask that preserves all bits outside the target element
    // and applies the operand mask to the target element.
    // mask = (operand << offset) | ~(elemMask << offset)
    Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
        loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
    Value storeVal =
        shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
    Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
        loc, dstType, elemMask, offset);
    Value invertedElemMask =
        rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
    Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
                                                           invertedElemMask);
    result = spirv::AtomicAndOp::create(
        rewriter, loc, dstType, adjustedPtr, *scope,
        spirv::MemorySemantics::AcquireRelease, mask);
    break;
  }
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

  // The atomic op returns the old value of the full storage element (e.g.,
  // i32). Extract the original sub-element value from the correct position.
  result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
                                                             result, offset);
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
  rewriter.replaceOp(atomicOp, result);

  return success();
}
578
579//===----------------------------------------------------------------------===//
580// DeallocOp
581//===----------------------------------------------------------------------===//
582
583LogicalResult
584DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
585 OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter) const {
587 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
588 if (!isAllocationSupported(operation, deallocType))
589 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
590 rewriter.eraseOp(operation);
591 return success();
592}
593
594//===----------------------------------------------------------------------===//
595// LoadOp
596//===----------------------------------------------------------------------===//
597
599 spirv::MemoryAccessAttr memoryAccess;
600 IntegerAttr alignment;
601};
602
/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
///
/// `preferredAlignment` of 0 means "no explicit alignment requested".
/// Returns failure when an alignment is required but cannot be determined, or
/// when the requested alignment does not fit a 32-bit operand.
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
                            uint64_t preferredAlignment) {
  // SPIR-V alignment operands are 32-bit; reject anything larger.
  if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
    return failure();
  }

  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  // Alignment may be omitted only when none was requested and the storage
  // class does not mandate it.
  bool mayOmitAlignment =
      !preferredAlignment &&
      ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
  if (mayOmitAlignment) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  // Other storage types may show an `Aligned` attribute.
  std::optional<int64_t> sizeInBytes;
  Type rawPointeeType = ptrType.getPointeeType();
  if (auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
    // For scalar types, the alignment is determined by their size.
    sizeInBytes = scalarType.getSizeInBytes();
  } else if (auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
    // For vector element types, the alignment should equal the total size of
    // the vector.
    if (auto scalarElem =
            dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
      if (auto elemSize = scalarElem.getSizeInBytes())
        sizeInBytes = *elemSize * vecType.getNumElements();
    }
  }

  // Pointee types whose size is unknown cannot satisfy an alignment
  // requirement.
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess |= spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  // An explicitly requested alignment wins over the natural (size-based) one.
  auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
  return MemoryRequirements{memAccessAttr, alignment};
}
657
658/// Given an accessed SPIR-V pointer and the original memref load/store
659/// `memAccess` op, calculates the alignment requirements, if any. Takes into
660/// account the alignment attributes applied to the load/store op.
661template <class LoadOrStoreOp>
662static FailureOr<MemoryRequirements>
663calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
664 static_assert(
665 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
666 "Must be called on either memref::LoadOp or memref::StoreOp");
667
668 return calculateMemoryRequirements(accessedPtr,
669 loadOrStoreOp.getNontemporal(),
670 loadOrStoreOp.getAlignment().value_or(0));
671}
672
673LogicalResult
674IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
675 ConversionPatternRewriter &rewriter) const {
676 auto loc = loadOp.getLoc();
677 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
678 if (!memrefType.getElementType().isSignlessInteger())
679 return failure();
680
681 auto memorySpaceAttr =
682 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
683 if (!memorySpaceAttr)
684 return rewriter.notifyMatchFailure(
685 loadOp, "missing memory space SPIR-V storage class attribute");
686
687 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
688 return rewriter.notifyMatchFailure(
689 loadOp,
690 "failed to lower memref in image storage class to storage buffer");
691
692 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
693 Value accessChain =
694 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
695 adaptor.getIndices(), loc, rewriter);
696
697 if (!accessChain)
698 return failure();
699
700 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
701 bool isBool = srcBits == 1;
702 if (isBool)
703 srcBits = typeConverter.getOptions().boolNumBits;
704
705 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
706 if (!pointerType)
707 return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
708
709 Type pointeeType = pointerType.getPointeeType();
710 Type dstType = getElementTypeForStoragePointer(pointeeType, typeConverter);
711 int dstBits = dstType.getIntOrFloatBitWidth();
712 assert(dstBits % srcBits == 0);
713
714 // If the rewritten load op has the same bit width, use the loading value
715 // directly.
716 if (srcBits == dstBits) {
717 auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
718 if (failed(memoryRequirements))
719 return rewriter.notifyMatchFailure(
720 loadOp, "failed to determine memory requirements");
721
722 auto [memoryAccess, alignment] = *memoryRequirements;
723 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
724 memoryAccess, alignment);
725 if (isBool)
726 loadVal = castIntNToBool(loc, loadVal, rewriter);
727 rewriter.replaceOp(loadOp, loadVal);
728 return success();
729 }
730
731 // Bitcasting is currently unsupported for Kernel capability /
732 // spirv.PtrAccessChain.
733 if (typeConverter.allows(spirv::Capability::Kernel))
734 return failure();
735
736 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
737 if (!accessChainOp)
738 return failure();
739
740 // Assume that getElementPtr() works linearizely. If it's a scalar, the method
741 // still returns a linearized accessing. If the accessing is not linearized,
742 // there will be offset issues.
743 assert(accessChainOp.getIndices().size() == 2);
744 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
745 srcBits, dstBits, rewriter);
746 auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
747 if (failed(memoryRequirements))
748 return rewriter.notifyMatchFailure(
749 loadOp, "failed to determine memory requirements");
750
751 auto [memoryAccess, alignment] = *memoryRequirements;
752 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
753 memoryAccess, alignment);
754
755 // Shift the bits to the rightmost.
756 // ____XXXX________ -> ____________XXXX
757 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
758 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
759 Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
760 loc, spvLoadOp.getType(), spvLoadOp, offset);
761
762 // Apply the mask to extract corresponding bits.
763 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
764 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
765 result =
766 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
767
768 // Apply sign extension on the loading value unconditionally. The signedness
769 // semantic is carried in the operator itself, we relies other pattern to
770 // handle the casting.
771 IntegerAttr shiftValueAttr =
772 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
773 Value shiftValue =
774 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
775 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
777 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
778 loc, dstType, result, shiftValue);
779
780 rewriter.replaceOp(loadOp, result);
781
782 assert(accessChainOp.use_empty());
783 rewriter.eraseOp(accessChainOp);
784
785 return success();
786}
787
788LogicalResult
789LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
790 ConversionPatternRewriter &rewriter) const {
791 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
792 if (memrefType.getElementType().isSignlessInteger())
793 return failure();
794
795 auto memorySpaceAttr =
796 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
797 if (!memorySpaceAttr)
798 return rewriter.notifyMatchFailure(
799 loadOp, "missing memory space SPIR-V storage class attribute");
800
801 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
802 return rewriter.notifyMatchFailure(
803 loadOp,
804 "failed to lower memref in image storage class to storage buffer");
805
806 Value loadPtr = spirv::getElementPtr(
807 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
808 adaptor.getIndices(), loadOp.getLoc(), rewriter);
809
810 if (!loadPtr)
811 return failure();
812
813 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
814 if (failed(memoryRequirements))
815 return rewriter.notifyMatchFailure(
816 loadOp, "failed to determine memory requirements");
817
818 auto [memoryAccess, alignment] = *memoryRequirements;
819 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
820 alignment);
821 return success();
822}
823
824template <typename OpAdaptor>
825static FailureOr<SmallVector<Value>>
826extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
827 ConversionPatternRewriter &rewriter) {
828 // At present we only support linear "tiling" as specified in Vulkan, this
829 // means that texels are assumed to be laid out in memory in a row-major
830 // order. This allows us to support any memref layout that is a permutation of
831 // the dimensions. Future work will pass an optional image layout to the
832 // rewrite pattern so that we can support optimized target specific tilings.
833 SmallVector<Value> indices = adaptor.getIndices();
834 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
835 if (!map.isPermutation())
836 return rewriter.notifyMatchFailure(
837 loadOp,
838 "Cannot lower memrefs with memory layout which is not a permutation");
839
840 // The memrefs layout determines the dimension ordering so we need to follow
841 // the map to get the ordering of the dimensions/indices.
842 const unsigned dimCount = map.getNumDims();
843 SmallVector<Value, 3> coords(dimCount);
844 for (unsigned dim = 0; dim < dimCount; ++dim)
845 coords[map.getDimPosition(dim)] = indices[dim];
846
847 // We need to reverse the coordinates because the memref layout is slowest to
848 // fastest moving and the vector coordinates for the image op is fastest to
849 // slowest moving.
850 return llvm::to_vector(llvm::reverse(coords));
851}
852
853LogicalResult
854ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
855 ConversionPatternRewriter &rewriter) const {
856 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
857
858 auto memorySpaceAttr =
859 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
860 if (!memorySpaceAttr)
861 return rewriter.notifyMatchFailure(
862 loadOp, "missing memory space SPIR-V storage class attribute");
863
864 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
865 return rewriter.notifyMatchFailure(
866 loadOp, "failed to lower memref in non-image storage class to image");
867
868 Value loadPtr = adaptor.getMemref();
869 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
870 if (failed(memoryRequirements))
871 return rewriter.notifyMatchFailure(
872 loadOp, "failed to determine memory requirements");
873
874 const auto [memoryAccess, alignment] = *memoryRequirements;
875
876 if (!loadOp.getMemRefType().hasRank())
877 return rewriter.notifyMatchFailure(
878 loadOp, "cannot lower unranked memrefs to SPIR-V images");
879
880 // We currently only support lowering of scalar memref elements to texels in
881 // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
882 // elements to texels in richer formats.
883 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
884 return rewriter.notifyMatchFailure(
885 loadOp,
886 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
887 "to SPIR-V images");
888
889 // We currently only support sampled images since OpImageFetch does not work
890 // for plain images and the OpImageRead instruction needs to be materialized
891 // instead or texels need to be accessed via atomics through a texel pointer.
892 // Future work will generalize support to plain images.
893 auto convertedPointeeType = cast<spirv::PointerType>(
894 getTypeConverter()->convertType(loadOp.getMemRefType()));
895 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
896 return rewriter.notifyMatchFailure(loadOp,
897 "cannot lower memrefs which do not "
898 "convert to SPIR-V sampled images");
899
900 // Materialize the lowering.
901 Location loc = loadOp->getLoc();
902 auto imageLoadOp =
903 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
904 // Extract the image from the sampled image.
905 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
906
907 // Build a vector of coordinates or just a scalar index if we have a 1D image.
908 Value coords;
909 if (memrefType.getRank() == 1) {
910 coords = adaptor.getIndices()[0];
911 } else {
912 FailureOr<SmallVector<Value>> maybeCoords =
913 extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
914 if (failed(maybeCoords))
915 return failure();
916 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
917 adaptor.getIndices().getType()[0]);
918 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
919 maybeCoords.value());
920 }
921
922 // Fetch the value out of the image.
923 auto resultVectorType = VectorType::get({4}, loadOp.getType());
924 auto fetchOp = spirv::ImageFetchOp::create(
925 rewriter, loc, resultVectorType, imageOp, coords,
926 mlir::spirv::ImageOperandsAttr{}, ValueRange{});
927
928 // Note that because OpImageFetch returns a rank 4 vector we need to extract
929 // the elements corresponding to the load which will since we only support the
930 // R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
931 auto compositeExtractOp =
932 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
933
934 rewriter.replaceOp(loadOp, compositeExtractOp);
935 return success();
936}
937
/// Lowers memref.store of signless-integer elements. Full-width stores become
/// a plain spirv.Store; narrower-than-storage stores (e.g. i8 into an i32
/// word) are emulated with atomic and/or read-modify-write operations.
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  // i1 is emulated as an integer of `boolNumBits` bits; treat the source as
  // having that width for the rest of the lowering.
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  auto dstType = dyn_cast<IntegerType>(
      getElementTypeForStoragePointer(pointeeType, typeConverter));
  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  // The storage width must be a whole multiple of the source width.
  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  // Same-width case: a direct spirv.Store suffices (with a bool->int cast
  // when the element type was i1).
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since there are multiple threads in the processing, the emulation will be
  // done with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
  assert(accessChainOp.getIndices().size() == 2);
  // Bit offset of the narrow element inside the wide storage word.
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  // Widen, mask, and shift the stored value into position, and rebuild the
  // access chain to address the wide storage word instead of the element.
  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                  srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = spirv::AtomicAndOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, clearBitsMask);
  result = spirv::AtomicOrOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, storeVal);

  // The AtomicOrOp has no side effect. Since it is already inserted, we can
  // just remove the original StoreOp. Note that rewriter.replaceOp()
  // doesn't work because it only accepts that the numbers of result are the
  // same.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
1050
1051//===----------------------------------------------------------------------===//
1052// MemorySpaceCastOp
1053//===----------------------------------------------------------------------===//
1054
/// Lowers memref.memory_space_cast between SPIR-V storage classes by routing
/// through the Generic storage class (requires the Kernel capability).
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  // Both sides must carry an explicit SPIR-V storage class memory space.
  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor destination have it.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
                                               result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}
1122
1123LogicalResult
1124StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1125 ConversionPatternRewriter &rewriter) const {
1126 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1127 if (memrefType.getElementType().isSignlessInteger())
1128 return rewriter.notifyMatchFailure(storeOp, "signless int");
1129 auto storePtr = spirv::getElementPtr(
1130 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1131 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1132
1133 if (!storePtr)
1134 return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
1135
1136 auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
1137 if (failed(memoryRequirements))
1138 return rewriter.notifyMatchFailure(
1139 storeOp, "failed to determine memory requirements");
1140
1141 auto [memoryAccess, alignment] = *memoryRequirements;
1142 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1143 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1144 return success();
1145}
1146
1147LogicalResult ReinterpretCastPattern::matchAndRewrite(
1148 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1149 ConversionPatternRewriter &rewriter) const {
1150 Value src = adaptor.getSource();
1151 auto srcType = dyn_cast<spirv::PointerType>(src.getType());
1152
1153 if (!srcType)
1154 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1155 diag << "invalid src type " << src.getType();
1156 });
1157
1158 const TypeConverter *converter = getTypeConverter();
1159
1160 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1161 if (dstType != srcType)
1162 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1163 diag << "invalid dst type " << op.getType();
1164 });
1165
1166 OpFoldResult offset =
1167 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1168 .front();
1169 if (isZeroInteger(offset)) {
1170 rewriter.replaceOp(op, src);
1171 return success();
1172 }
1173
1174 Type intType = converter->convertType(rewriter.getIndexType());
1175 if (!intType)
1176 return rewriter.notifyMatchFailure(op, "failed to convert index type");
1177
1178 Location loc = op.getLoc();
1179 auto offsetValue = [&]() -> Value {
1180 if (auto val = dyn_cast<Value>(offset))
1181 return val;
1182
1183 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1184 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1185 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1186 }();
1187
1188 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1189 op, src, offsetValue, ValueRange());
1190 return success();
1191}
1192
1193//===----------------------------------------------------------------------===//
1194// ExtractAlignedPointerAsIndexOp
1195//===----------------------------------------------------------------------===//
1196
1197LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1198 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1199 ConversionPatternRewriter &rewriter) const {
1200 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1201 Type indexType = typeConverter.getIndexType();
1202 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1203 adaptor.getSource());
1204 return success();
1205}
1206
1207//===----------------------------------------------------------------------===//
1208// Pattern population
1209//===----------------------------------------------------------------------===//
1210
1211namespace mlir {
1213 RewritePatternSet &patterns) {
1214 patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1215 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1216 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1217 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1218 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
1219 patterns.getContext());
1220}
1221} // namespace mlir
return success()
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Type getElementTypeForStoragePointer(Type pointeeType, const SPIRVTypeConverter &typeConverter)
Extracts the element type from a SPIR-V pointer type pointing to storage.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
iterator begin()
Definition Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess