MLIR 23.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Visitors.h"
24#include <cassert>
25#include <limits>
26#include <optional>
27
28#define DEBUG_TYPE "memref-to-spirv-pattern"
29
30using namespace mlir;
31
32//===----------------------------------------------------------------------===//
33// Utility functions
34//===----------------------------------------------------------------------===//
35
36/// Returns the offset of the value in `targetBits` representation.
37///
38/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
39/// It's assumed to be non-negative.
40///
41/// When accessing an element in the array treating as having elements of
42/// `targetBits`, multiple values are loaded in the same time. The method
43/// returns the offset where the `srcIdx` locates in the value. For example, if
44/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
45/// located at (x % 4) * 8. Because there are four elements in one i32, and one
46/// element has 8 bits.
47static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
48 int targetBits, OpBuilder &builder) {
49 assert(targetBits % sourceBits == 0);
50 Type type = srcIdx.getType();
51 IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
54 auto srcBitsValue =
55 builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
58}
59
60/// Returns an adjusted spirv::AccessChainOp. Based on the
61/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
62/// supported. During conversion if a memref of an unsupported type is used,
63/// load/stores to this memref need to be modified to use a supported higher
64/// bitwidth `targetBits` and extracting the required bits. For an accessing a
65/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
66/// the bits needed. The extraction of the actual bits needed are handled
67/// separately. Note that this only works for a 1-D tensor.
68static Value
70 spirv::AccessChainOp op, int sourceBits,
71 int targetBits, OpBuilder &builder) {
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
75 Type type = lastDim.getType();
76 IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
79 // There are two elements if this is a 1-D tensor.
80 assert(indices.size() == 2);
81 indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
84 indices);
85}
86
87/// Casts the given `srcBool` into an integer of `dstType`.
88static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
89 OpBuilder &builder) {
90 assert(srcBool.getType().isInteger(1));
91 if (dstType.isInteger(1))
92 return srcBool;
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
96 zero);
97}
98
99/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
100/// to the type destination type, and masked.
101static Value shiftValue(Location loc, Value value, Value offset, Value mask,
102 OpBuilder &builder) {
103 IntegerType dstType = cast<IntegerType>(mask.getType());
104 int targetBits = static_cast<int>(dstType.getWidth());
105 int valueBits = value.getType().getIntOrFloatBitWidth();
106 assert(valueBits <= targetBits);
107
108 if (valueBits == 1) {
109 value = castBoolToIntN(loc, value, dstType, builder);
110 } else {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
113 builder, loc, builder.getIntegerType(targetBits), value);
114 }
115
116 value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
117 }
118 return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
119 value, offset);
120}
121
122/// Returns true if the allocations of memref `type` generated from `allocOp`
123/// can be lowered to SPIR-V.
124static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
128 return false;
129 } else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
132 return false;
133 } else {
134 return false;
135 }
136
137 // Currently only support static shape and int or float, complex of int or
138 // float, or vector of int or float element type.
139 if (!type.hasStaticShape())
140 return false;
141
142 Type elementType = type.getElementType();
143 if (auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
147 return elementType.isIntOrFloat();
148}
149
150/// Returns the scope to use for atomic operations use for emulating store
151/// operations of unsupported integer bitwidths, based on the memref
152/// type. Returns std::nullopt on failure.
153static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
160 default:
161 break;
162 }
163 return {};
164}
165
166/// Returns the MemorySemantics storage-class bit corresponding to `sc`.
167/// Per SPIR-V spec section 3.32 (Memory Semantics) this bit must be OR'd
168/// with the ordering bits (Acquire/Release/...) on atomic operations.
169static spirv::MemorySemantics
170getMemorySemanticsForStorageClass(spirv::StorageClass sc) {
171 switch (sc) {
172 case spirv::StorageClass::StorageBuffer:
173 case spirv::StorageClass::Uniform:
174 return spirv::MemorySemantics::UniformMemory;
175 case spirv::StorageClass::Workgroup:
176 return spirv::MemorySemantics::WorkgroupMemory;
177 case spirv::StorageClass::CrossWorkgroup:
178 return spirv::MemorySemantics::CrossWorkgroupMemory;
179 case spirv::StorageClass::AtomicCounter:
180 return spirv::MemorySemantics::AtomicCounterMemory;
181 case spirv::StorageClass::Image:
182 return spirv::MemorySemantics::ImageMemory;
183 default:
184 return spirv::MemorySemantics::None;
185 }
186}
187
188/// Returns the AcquireRelease memory semantics OR'd with the storage-class
189/// bit derived from the memory space of `type`.
190static spirv::MemorySemantics getAtomicAcqRelMemorySemantics(MemRefType type) {
191 auto sc = cast<spirv::StorageClassAttr>(type.getMemorySpace()).getValue();
192 return spirv::MemorySemantics::AcquireRelease |
194}
195
196/// Extracts the element type from a SPIR-V pointer type pointing to storage.
197///
198/// For Kernel capability, the pointer points directly to the element type
199/// (possibly wrapped in an array). For Vulkan, the pointer points to a struct
200/// containing an array or runtime array, and we need to unwrap to get the
201/// element type.
202static Type
204 const SPIRVTypeConverter &typeConverter) {
205 if (typeConverter.allows(spirv::Capability::Kernel)) {
206 if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
207 return arrayType.getElementType();
208 return pointeeType;
209 }
210 // For Vulkan we need to extract element from wrapping struct and array.
211 Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
212 if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
213 return arrayType.getElementType();
214 return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
215}
216
217/// Casts the given `srcInt` into a boolean value.
218static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
219 if (srcInt.getType().isInteger(1))
220 return srcInt;
221
222 auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
223 return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
224}
225
226//===----------------------------------------------------------------------===//
227// Operation conversion
228//===----------------------------------------------------------------------===//
229
230// Note that DRR cannot be used for the patterns in this file: we may need to
231// convert type along the way, which requires ConversionPattern. DRR generates
232// normal RewritePattern.
233
234namespace {
235
236/// Converts memref.alloca to SPIR-V Function variables.
237class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
238public:
239 using Base::Base;
240
241 LogicalResult
242 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
243 ConversionPatternRewriter &rewriter) const override;
244};
245
246/// Converts an allocation operation to SPIR-V. Currently only supports lowering
247/// to Workgroup memory when the size is constant. Note that this pattern needs
248/// to be applied in a pass that runs at least at spirv.module scope since it
249/// wil ladd global variables into the spirv.module.
250class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
251public:
252 using Base::Base;
253
254 LogicalResult
255 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
256 ConversionPatternRewriter &rewriter) const override;
257};
258
259/// Converts memref.automic_rmw operations to SPIR-V atomic operations.
260class AtomicRMWOpPattern final
261 : public OpConversionPattern<memref::AtomicRMWOp> {
262public:
263 using Base::Base;
264
265 LogicalResult
266 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
267 ConversionPatternRewriter &rewriter) const override;
268};
269
270/// Removed a deallocation if it is a supported allocation. Currently only
271/// removes deallocation if the memory space is workgroup memory.
272class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
273public:
274 using Base::Base;
275
276 LogicalResult
277 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override;
279};
280
281/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
282class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
283public:
284 using Base::Base;
285
286 LogicalResult
287 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter) const override;
289};
290
291/// Converts memref.load to spirv.Load + spirv.AccessChain.
292class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
293public:
294 using Base::Base;
295
296 LogicalResult
297 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter) const override;
299};
300
301/// Converts memref.load to spirv.Image + spirv.ImageFetch
302class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
303public:
304 using Base::Base;
305
306 LogicalResult
307 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter) const override;
309};
310
311/// Converts memref.store to spirv.Store on integers.
312class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
313public:
314 using Base::Base;
315
316 LogicalResult
317 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
318 ConversionPatternRewriter &rewriter) const override;
319};
320
321/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
322class MemorySpaceCastOpPattern final
323 : public OpConversionPattern<memref::MemorySpaceCastOp> {
324public:
325 using Base::Base;
326
327 LogicalResult
328 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter) const override;
330};
331
332/// Converts memref.store to spirv.Store.
333class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
334public:
335 using Base::Base;
336
337 LogicalResult
338 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter) const override;
340};
341
342class ReinterpretCastPattern final
343 : public OpConversionPattern<memref::ReinterpretCastOp> {
344public:
345 using Base::Base;
346
347 LogicalResult
348 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override;
350};
351
352class CastPattern final : public OpConversionPattern<memref::CastOp> {
353public:
354 using Base::Base;
355
356 LogicalResult
357 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 Value src = adaptor.getSource();
360 Type srcType = src.getType();
361
362 const TypeConverter *converter = getTypeConverter();
363 Type dstType = converter->convertType(op.getType());
364 if (srcType != dstType)
365 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
366 diag << "types doesn't match: " << srcType << " and " << dstType;
367 });
368
369 rewriter.replaceOp(op, src);
370 return success();
371 }
372};
373
374/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
375class ExtractAlignedPointerAsIndexOpPattern final
376 : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
377public:
378 using Base::Base;
379
380 LogicalResult
381 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
382 OpAdaptor adaptor,
383 ConversionPatternRewriter &rewriter) const override;
384};
385} // namespace
386
387//===----------------------------------------------------------------------===//
388// AllocaOp
389//===----------------------------------------------------------------------===//
390
391LogicalResult
392AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
393 ConversionPatternRewriter &rewriter) const {
394 MemRefType allocType = allocaOp.getType();
395 if (!isAllocationSupported(allocaOp, allocType))
396 return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
397
398 // Get the SPIR-V type for the allocation.
399 Type spirvType = getTypeConverter()->convertType(allocType);
400 if (!spirvType)
401 return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
402
403 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
404 spirv::StorageClass::Function,
405 /*initializer=*/nullptr);
406 return success();
407}
408
409//===----------------------------------------------------------------------===//
410// AllocOp
411//===----------------------------------------------------------------------===//
412
413LogicalResult
414AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter) const {
416 MemRefType allocType = operation.getType();
417 if (!isAllocationSupported(operation, allocType))
418 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
419
420 // Get the SPIR-V type for the allocation.
421 Type spirvType = getTypeConverter()->convertType(allocType);
422 if (!spirvType)
423 return rewriter.notifyMatchFailure(operation, "type conversion failed");
424
425 // Insert spirv.GlobalVariable for this allocation.
426 Operation *parent =
427 SymbolTable::getNearestSymbolTable(operation->getParentOp());
428 if (!parent)
429 return failure();
430 Location loc = operation.getLoc();
431 spirv::GlobalVariableOp varOp;
432 {
433 OpBuilder::InsertionGuard guard(rewriter);
434 Block &entryBlock = *parent->getRegion(0).begin();
435 rewriter.setInsertionPointToStart(&entryBlock);
436 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
437 std::string varName =
438 std::string("__workgroup_mem__") +
439 std::to_string(std::distance(varOps.begin(), varOps.end()));
440 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
441 /*initializer=*/nullptr);
442 }
443
444 // Get pointer to global variable at the current scope.
445 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
446 return success();
447}
448
449//===----------------------------------------------------------------------===//
450// AtomicRMWOp
451//===----------------------------------------------------------------------===//
452
453LogicalResult
454AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
455 OpAdaptor adaptor,
456 ConversionPatternRewriter &rewriter) const {
457 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
458 std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
459 if (!scope)
460 return rewriter.notifyMatchFailure(atomicOp,
461 "unsupported memref memory space");
462
463 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
464 Type resultType = typeConverter.convertType(atomicOp.getType());
465 if (!resultType)
466 return rewriter.notifyMatchFailure(atomicOp,
467 "failed to convert result type");
468
469 auto loc = atomicOp.getLoc();
470 Value ptr =
471 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
472 adaptor.getIndices(), loc, rewriter);
473
474 if (!ptr)
475 return failure();
476
477 // Determine the source and destination bitwidths. The source is the original
478 // memref element type and the destination is the SPIR-V storage type (e.g.,
479 // i32 for Vulkan).
480 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
481 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
482 if (!pointerType)
483 return rewriter.notifyMatchFailure(atomicOp,
484 "failed to convert memref type");
485
486 Type pointeeType = pointerType.getPointeeType();
487 Type storageElemType =
488 getElementTypeForStoragePointer(pointeeType, typeConverter);
489 if (!storageElemType || !storageElemType.isIntOrFloat())
490 return rewriter.notifyMatchFailure(
491 atomicOp, "failed to determine destination element type");
492
493 int dstBits = static_cast<int>(storageElemType.getIntOrFloatBitWidth());
494 assert(dstBits % srcBits == 0);
495
496 spirv::MemorySemantics memSem = getAtomicAcqRelMemorySemantics(memrefType);
497
498 // When the source and destination bitwidths match, emit the atomic operation
499 // directly.
500 if (srcBits == dstBits) {
501#define ATOMIC_CASE(kind, spirvOp) \
502 case arith::AtomicRMWKind::kind: \
503 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
504 atomicOp, resultType, ptr, *scope, memSem, adaptor.getValue()); \
505 break
506
507 switch (atomicOp.getKind()) {
508 ATOMIC_CASE(addf, EXTAtomicFAddOp);
509 ATOMIC_CASE(addi, AtomicIAddOp);
510 ATOMIC_CASE(maxs, AtomicSMaxOp);
511 ATOMIC_CASE(maxu, AtomicUMaxOp);
512 ATOMIC_CASE(mins, AtomicSMinOp);
513 ATOMIC_CASE(minu, AtomicUMinOp);
514 ATOMIC_CASE(ori, AtomicOrOp);
515 ATOMIC_CASE(andi, AtomicAndOp);
516 default:
517 return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
518 }
519
520#undef ATOMIC_CASE
521
522 return success();
523 }
524
525 // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
526 // storage type (e.g., i32). We need to adjust the index and shift/mask the
527 // value to operate on the correct bits within the wider storage element.
528 //
529 // Only ori and andi can be emulated because they operate bitwise and don't
530 // carry across byte boundaries. Other kinds (addi, max, min) would require
531 // CAS loops.
532 if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
533 atomicOp.getKind() != arith::AtomicRMWKind::andi) {
534 return rewriter.notifyMatchFailure(
535 atomicOp,
536 "atomic op on sub-element-width types is only supported for ori/andi");
537 }
538
539 // Bitcasting is currently unsupported for Kernel capability /
540 // spirv.PtrAccessChain.
541 if (typeConverter.allows(spirv::Capability::Kernel))
542 return rewriter.notifyMatchFailure(
543 atomicOp,
544 "sub-element-width atomic ops unsupported with Kernel capability");
545
546 auto dstType = cast<IntegerType>(storageElemType);
547
548 auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
549 if (!accessChainOp)
550 return failure();
551
552 // Compute the bit offset within the storage element and adjust the pointer
553 // to address the containing storage element.
554 assert(accessChainOp.getIndices().size() == 2);
555 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
556 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
557 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
558 srcBits, dstBits, rewriter);
559 Value result;
560 switch (atomicOp.getKind()) {
561 case arith::AtomicRMWKind::ori: {
562 // OR only sets bits, so shifting the value to the target position and
563 // ORing with zeros in other positions preserves the unaffected bits.
564 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
565 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
566 Value storeVal =
567 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
568 result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
569 *scope, memSem, storeVal);
570 break;
571 }
572 case arith::AtomicRMWKind::andi: {
573 // Build a mask that preserves all bits outside the target element
574 // and applies the operand mask to the target element.
575 // mask = (operand << offset) | ~(elemMask << offset)
576 Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
577 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
578 Value storeVal =
579 shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
580 Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
581 loc, dstType, elemMask, offset);
582 Value invertedElemMask =
583 rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
584 Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
585 invertedElemMask);
586 result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
587 *scope, memSem, mask);
588 break;
589 }
590 default:
591 return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
592 }
593
594 // The atomic op returns the old value of the full storage element (e.g.,
595 // i32). Extract the original sub-element value from the correct position.
596 result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
597 result, offset);
598 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
599 loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
600 result =
601 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
602 rewriter.replaceOp(atomicOp, result);
603
604 return success();
605}
606
607//===----------------------------------------------------------------------===//
608// DeallocOp
609//===----------------------------------------------------------------------===//
610
611LogicalResult
612DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
613 OpAdaptor adaptor,
614 ConversionPatternRewriter &rewriter) const {
615 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
616 if (!isAllocationSupported(operation, deallocType))
617 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
618 rewriter.eraseOp(operation);
619 return success();
620}
621
622//===----------------------------------------------------------------------===//
623// LoadOp
624//===----------------------------------------------------------------------===//
625
627 spirv::MemoryAccessAttr memoryAccess;
628 IntegerAttr alignment;
629};
630
631/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
632/// any.
633static FailureOr<MemoryRequirements>
634calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
635 uint64_t preferredAlignment) {
636 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
637 return failure();
638 }
639
640 MLIRContext *ctx = accessedPtr.getContext();
641
642 auto memoryAccess = spirv::MemoryAccess::None;
643 if (isNontemporal) {
644 memoryAccess = spirv::MemoryAccess::Nontemporal;
645 }
646
647 auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
648 bool mayOmitAlignment =
649 !preferredAlignment &&
650 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
651 if (mayOmitAlignment) {
652 if (memoryAccess == spirv::MemoryAccess::None) {
653 return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
654 }
655 return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
656 IntegerAttr{}};
657 }
658
659 // PhysicalStorageBuffers require the `Aligned` attribute.
660 // Other storage types may show an `Aligned` attribute.
661 std::optional<int64_t> sizeInBytes;
662 Type rawPointeeType = ptrType.getPointeeType();
663 if (auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
664 // For scalar types, the alignment is determined by their size.
665 sizeInBytes = scalarType.getSizeInBytes();
666 } else if (auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
667 // For vector element types, the alignment should equal the total size of
668 // the vector.
669 if (auto scalarElem =
670 dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
671 if (auto elemSize = scalarElem.getSizeInBytes())
672 sizeInBytes = *elemSize * vecType.getNumElements();
673 }
674 }
675
676 if (!sizeInBytes.has_value())
677 return failure();
678
679 memoryAccess |= spirv::MemoryAccess::Aligned;
680 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
681 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
682 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
683 return MemoryRequirements{memAccessAttr, alignment};
684}
685
686/// Given an accessed SPIR-V pointer and the original memref load/store
687/// `memAccess` op, calculates the alignment requirements, if any. Takes into
688/// account the alignment attributes applied to the load/store op.
689template <class LoadOrStoreOp>
690static FailureOr<MemoryRequirements>
691calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
692 static_assert(
693 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
694 "Must be called on either memref::LoadOp or memref::StoreOp");
695
696 return calculateMemoryRequirements(accessedPtr,
697 loadOrStoreOp.getNontemporal(),
698 loadOrStoreOp.getAlignment().value_or(0));
699}
700
701LogicalResult
702IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
703 ConversionPatternRewriter &rewriter) const {
704 auto loc = loadOp.getLoc();
705 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
706 if (!memrefType.getElementType().isSignlessInteger())
707 return failure();
708
709 auto memorySpaceAttr =
710 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
711 if (!memorySpaceAttr)
712 return rewriter.notifyMatchFailure(
713 loadOp, "missing memory space SPIR-V storage class attribute");
714
715 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
716 return rewriter.notifyMatchFailure(
717 loadOp,
718 "failed to lower memref in image storage class to storage buffer");
719
720 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
721 Value accessChain =
722 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
723 adaptor.getIndices(), loc, rewriter);
724
725 if (!accessChain)
726 return failure();
727
728 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
729 bool isBool = srcBits == 1;
730 if (isBool)
731 srcBits = typeConverter.getOptions().boolNumBits;
732
733 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
734 if (!pointerType)
735 return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
736
737 Type pointeeType = pointerType.getPointeeType();
738 Type dstType = getElementTypeForStoragePointer(pointeeType, typeConverter);
739 int dstBits = dstType.getIntOrFloatBitWidth();
740 assert(dstBits % srcBits == 0);
741
742 // If the rewritten load op has the same bit width, use the loading value
743 // directly.
744 if (srcBits == dstBits) {
745 auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
746 if (failed(memoryRequirements))
747 return rewriter.notifyMatchFailure(
748 loadOp, "failed to determine memory requirements");
749
750 auto [memoryAccess, alignment] = *memoryRequirements;
751 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
752 memoryAccess, alignment);
753 if (isBool)
754 loadVal = castIntNToBool(loc, loadVal, rewriter);
755 rewriter.replaceOp(loadOp, loadVal);
756 return success();
757 }
758
759 // Bitcasting is currently unsupported for Kernel capability /
760 // spirv.PtrAccessChain.
761 if (typeConverter.allows(spirv::Capability::Kernel))
762 return failure();
763
764 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
765 if (!accessChainOp)
766 return failure();
767
768 // Assume that getElementPtr() works linearizely. If it's a scalar, the method
769 // still returns a linearized accessing. If the accessing is not linearized,
770 // there will be offset issues.
771 assert(accessChainOp.getIndices().size() == 2);
772 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
773 srcBits, dstBits, rewriter);
774 auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
775 if (failed(memoryRequirements))
776 return rewriter.notifyMatchFailure(
777 loadOp, "failed to determine memory requirements");
778
779 auto [memoryAccess, alignment] = *memoryRequirements;
780 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
781 memoryAccess, alignment);
782
783 // Shift the bits to the rightmost.
784 // ____XXXX________ -> ____________XXXX
785 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
786 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
787 Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
788 loc, spvLoadOp.getType(), spvLoadOp, offset);
789
790 // Apply the mask to extract corresponding bits.
791 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
792 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
793 result =
794 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
795
796 // Apply sign extension on the loading value unconditionally. The signedness
797 // semantic is carried in the operator itself, we relies other pattern to
798 // handle the casting.
799 IntegerAttr shiftValueAttr =
800 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
801 Value shiftValue =
802 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
803 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
805 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
806 loc, dstType, result, shiftValue);
807
808 rewriter.replaceOp(loadOp, result);
809
810 assert(accessChainOp.use_empty());
811 rewriter.eraseOp(accessChainOp);
812
813 return success();
814}
815
816LogicalResult
817LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter) const {
819 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
820 if (memrefType.getElementType().isSignlessInteger())
821 return failure();
822
823 auto memorySpaceAttr =
824 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
825 if (!memorySpaceAttr)
826 return rewriter.notifyMatchFailure(
827 loadOp, "missing memory space SPIR-V storage class attribute");
828
829 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
830 return rewriter.notifyMatchFailure(
831 loadOp,
832 "failed to lower memref in image storage class to storage buffer");
833
834 Value loadPtr = spirv::getElementPtr(
835 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
836 adaptor.getIndices(), loadOp.getLoc(), rewriter);
837
838 if (!loadPtr)
839 return failure();
840
841 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
842 if (failed(memoryRequirements))
843 return rewriter.notifyMatchFailure(
844 loadOp, "failed to determine memory requirements");
845
846 auto [memoryAccess, alignment] = *memoryRequirements;
847 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
848 alignment);
849 return success();
850}
851
852template <typename OpAdaptor>
853static FailureOr<SmallVector<Value>>
854extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
855 ConversionPatternRewriter &rewriter) {
856 // At present we only support linear "tiling" as specified in Vulkan, this
857 // means that texels are assumed to be laid out in memory in a row-major
858 // order. This allows us to support any memref layout that is a permutation of
859 // the dimensions. Future work will pass an optional image layout to the
860 // rewrite pattern so that we can support optimized target specific tilings.
861 SmallVector<Value> indices = adaptor.getIndices();
862 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
863 if (!map.isPermutation())
864 return rewriter.notifyMatchFailure(
865 loadOp,
866 "Cannot lower memrefs with memory layout which is not a permutation");
867
868 // The memrefs layout determines the dimension ordering so we need to follow
869 // the map to get the ordering of the dimensions/indices.
870 const unsigned dimCount = map.getNumDims();
871 SmallVector<Value, 3> coords(dimCount);
872 for (unsigned dim = 0; dim < dimCount; ++dim)
873 coords[map.getDimPosition(dim)] = indices[dim];
874
875 // We need to reverse the coordinates because the memref layout is slowest to
876 // fastest moving and the vector coordinates for the image op is fastest to
877 // slowest moving.
878 return llvm::to_vector(llvm::reverse(coords));
879}
880
881LogicalResult
882ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
883 ConversionPatternRewriter &rewriter) const {
884 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
885
886 auto memorySpaceAttr =
887 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
888 if (!memorySpaceAttr)
889 return rewriter.notifyMatchFailure(
890 loadOp, "missing memory space SPIR-V storage class attribute");
891
892 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
893 return rewriter.notifyMatchFailure(
894 loadOp, "failed to lower memref in non-image storage class to image");
895
896 Value loadPtr = adaptor.getMemref();
897 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
898 if (failed(memoryRequirements))
899 return rewriter.notifyMatchFailure(
900 loadOp, "failed to determine memory requirements");
901
902 const auto [memoryAccess, alignment] = *memoryRequirements;
903
904 if (!loadOp.getMemRefType().hasRank())
905 return rewriter.notifyMatchFailure(
906 loadOp, "cannot lower unranked memrefs to SPIR-V images");
907
908 // We currently only support lowering of scalar memref elements to texels in
909 // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
910 // elements to texels in richer formats.
911 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
912 return rewriter.notifyMatchFailure(
913 loadOp,
914 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
915 "to SPIR-V images");
916
917 // We currently only support sampled images since OpImageFetch does not work
918 // for plain images and the OpImageRead instruction needs to be materialized
919 // instead or texels need to be accessed via atomics through a texel pointer.
920 // Future work will generalize support to plain images.
921 auto convertedPointeeType = cast<spirv::PointerType>(
922 getTypeConverter()->convertType(loadOp.getMemRefType()));
923 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
924 return rewriter.notifyMatchFailure(loadOp,
925 "cannot lower memrefs which do not "
926 "convert to SPIR-V sampled images");
927
928 // Materialize the lowering.
929 Location loc = loadOp->getLoc();
930 auto imageLoadOp =
931 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
932 // Extract the image from the sampled image.
933 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
934
935 // Build a vector of coordinates or just a scalar index if we have a 1D image.
936 Value coords;
937 if (memrefType.getRank() == 1) {
938 coords = adaptor.getIndices()[0];
939 } else {
940 FailureOr<SmallVector<Value>> maybeCoords =
941 extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
942 if (failed(maybeCoords))
943 return failure();
944 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
945 adaptor.getIndices().getType()[0]);
946 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
947 maybeCoords.value());
948 }
949
950 // Fetch the value out of the image.
951 auto resultVectorType = VectorType::get({4}, loadOp.getType());
952 auto fetchOp = spirv::ImageFetchOp::create(
953 rewriter, loc, resultVectorType, imageOp, coords,
954 mlir::spirv::ImageOperandsAttr{}, ValueRange{});
955
956 // Note that because OpImageFetch returns a rank 4 vector we need to extract
957 // the elements corresponding to the load which will since we only support the
958 // R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
959 auto compositeExtractOp =
960 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
961
962 rewriter.replaceOp(loadOp, compositeExtractOp);
963 return success();
964}
965
966LogicalResult
967IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
968 ConversionPatternRewriter &rewriter) const {
969 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
970 if (!memrefType.getElementType().isSignlessInteger())
971 return rewriter.notifyMatchFailure(storeOp,
972 "element type is not a signless int");
973
974 auto loc = storeOp.getLoc();
975 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
976 Value accessChain =
977 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
978 adaptor.getIndices(), loc, rewriter);
979
980 if (!accessChain)
981 return rewriter.notifyMatchFailure(
982 storeOp, "failed to convert element pointer type");
983
984 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
985
986 bool isBool = srcBits == 1;
987 if (isBool)
988 srcBits = typeConverter.getOptions().boolNumBits;
989
990 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
991 if (!pointerType)
992 return rewriter.notifyMatchFailure(storeOp,
993 "failed to convert memref type");
994
995 Type pointeeType = pointerType.getPointeeType();
996 auto dstType = dyn_cast<IntegerType>(
997 getElementTypeForStoragePointer(pointeeType, typeConverter));
998 if (!dstType)
999 return rewriter.notifyMatchFailure(
1000 storeOp, "failed to determine destination element type");
1001
1002 int dstBits = static_cast<int>(dstType.getWidth());
1003 assert(dstBits % srcBits == 0);
1004
1005 if (srcBits == dstBits) {
1006 auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
1007 if (failed(memoryRequirements))
1008 return rewriter.notifyMatchFailure(
1009 storeOp, "failed to determine memory requirements");
1010
1011 auto [memoryAccess, alignment] = *memoryRequirements;
1012 Value storeVal = adaptor.getValue();
1013 if (isBool)
1014 storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
1015 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
1016 memoryAccess, alignment);
1017 return success();
1018 }
1019
1020 // Bitcasting is currently unsupported for Kernel capability /
1021 // spirv.PtrAccessChain.
1022 if (typeConverter.allows(spirv::Capability::Kernel))
1023 return failure();
1024
1025 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
1026 if (!accessChainOp)
1027 return failure();
1028
1029 // Since there are multiple threads in the processing, the emulation will be
1030 // done with atomic operations. E.g., if the stored value is i8, rewrite the
1031 // StoreOp to:
1032 // 1) load a 32-bit integer
1033 // 2) clear 8 bits in the loaded value
1034 // 3) set 8 bits in the loaded value
1035 // 4) store 32-bit value back
1036 //
1037 // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
1038 // loaded 32-bit value and the shifted 8-bit store value) as another atomic
1039 // step.
1040 assert(accessChainOp.getIndices().size() == 2);
1041 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
1042 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
1043
1044 // Create a mask to clear the destination. E.g., if it is the second i8 in
1045 // i32, 0xFFFF00FF is created.
1046 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
1047 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
1048 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
1049 loc, dstType, mask, offset);
1050 clearBitsMask =
1051 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
1052
1053 Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
1054 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
1055 srcBits, dstBits, rewriter);
1056 std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
1057 if (!scope)
1058 return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
1059
1060 spirv::MemorySemantics memSem = getAtomicAcqRelMemorySemantics(memrefType);
1061 Value result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
1062 *scope, memSem, clearBitsMask);
1063 result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
1064 *scope, memSem, storeVal);
1065
1066 // The AtomicOrOp has no side effect. Since it is already inserted, we can
1067 // just remove the original StoreOp. Note that rewriter.replaceOp()
1068 // doesn't work because it only accepts that the numbers of result are the
1069 // same.
1070 rewriter.eraseOp(storeOp);
1071
1072 assert(accessChainOp.use_empty());
1073 rewriter.eraseOp(accessChainOp);
1074
1075 return success();
1076}
1077
1078//===----------------------------------------------------------------------===//
1079// MemorySpaceCastOp
1080//===----------------------------------------------------------------------===//
1081
1082LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
1083 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
1084 ConversionPatternRewriter &rewriter) const {
1085 Location loc = addrCastOp.getLoc();
1086 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1087 if (!typeConverter.allows(spirv::Capability::Kernel))
1088 return rewriter.notifyMatchFailure(
1089 loc, "address space casts require kernel capability");
1090
1091 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
1092 if (!sourceType)
1093 return rewriter.notifyMatchFailure(
1094 loc, "SPIR-V lowering requires ranked memref types");
1095 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
1096
1097 auto sourceStorageClassAttr =
1098 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
1099 if (!sourceStorageClassAttr)
1100 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
1101 diag << "source address space " << sourceType.getMemorySpace()
1102 << " must be a SPIR-V storage class";
1103 });
1104 auto resultStorageClassAttr =
1105 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
1106 if (!resultStorageClassAttr)
1107 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
1108 diag << "result address space " << resultType.getMemorySpace()
1109 << " must be a SPIR-V storage class";
1110 });
1111
1112 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
1113 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
1114
1115 Value result = adaptor.getSource();
1116 Type resultPtrType = typeConverter.convertType(resultType);
1117 if (!resultPtrType)
1118 return rewriter.notifyMatchFailure(addrCastOp,
1119 "failed to convert memref type");
1120
1121 Type genericPtrType = resultPtrType;
1122 // SPIR-V doesn't have a general address space cast operation. Instead, it has
1123 // conversions to and from generic pointers. To implement the general case,
1124 // we use specific-to-generic conversions when the source class is not
1125 // generic. Then when the result storage class is not generic, we convert the
1126 // generic pointer (either the input on ar intermediate result) to that
1127 // class. This also means that we'll need the intermediate generic pointer
1128 // type if neither the source or destination have it.
1129 if (sourceSc != spirv::StorageClass::Generic &&
1130 resultSc != spirv::StorageClass::Generic) {
1131 Type intermediateType =
1132 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
1133 sourceType.getLayout(),
1134 rewriter.getAttr<spirv::StorageClassAttr>(
1135 spirv::StorageClass::Generic));
1136 genericPtrType = typeConverter.convertType(intermediateType);
1137 }
1138 if (sourceSc != spirv::StorageClass::Generic) {
1139 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1140 result);
1141 }
1142 if (resultSc != spirv::StorageClass::Generic) {
1143 result =
1144 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
1145 }
1146 rewriter.replaceOp(addrCastOp, result);
1147 return success();
1148}
1149
1150LogicalResult
1151StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1152 ConversionPatternRewriter &rewriter) const {
1153 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1154 if (memrefType.getElementType().isSignlessInteger())
1155 return rewriter.notifyMatchFailure(storeOp, "signless int");
1156 auto storePtr = spirv::getElementPtr(
1157 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1158 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1159
1160 if (!storePtr)
1161 return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
1162
1163 auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
1164 if (failed(memoryRequirements))
1165 return rewriter.notifyMatchFailure(
1166 storeOp, "failed to determine memory requirements");
1167
1168 auto [memoryAccess, alignment] = *memoryRequirements;
1169 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1170 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1171 return success();
1172}
1173
1174LogicalResult ReinterpretCastPattern::matchAndRewrite(
1175 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter) const {
1177 Value src = adaptor.getSource();
1178 auto srcType = dyn_cast<spirv::PointerType>(src.getType());
1179
1180 if (!srcType)
1181 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1182 diag << "invalid src type " << src.getType();
1183 });
1184
1185 const TypeConverter *converter = getTypeConverter();
1186
1187 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1188 if (dstType != srcType)
1189 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1190 diag << "invalid dst type " << op.getType();
1191 });
1192
1193 OpFoldResult offset =
1194 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1195 .front();
1196 if (isZeroInteger(offset)) {
1197 rewriter.replaceOp(op, src);
1198 return success();
1199 }
1200
1201 Type intType = converter->convertType(rewriter.getIndexType());
1202 if (!intType)
1203 return rewriter.notifyMatchFailure(op, "failed to convert index type");
1204
1205 Location loc = op.getLoc();
1206 auto offsetValue = [&]() -> Value {
1207 if (auto val = dyn_cast<Value>(offset))
1208 return val;
1209
1210 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1211 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1212 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1213 }();
1214
1215 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1216 op, src, offsetValue, ValueRange());
1217 return success();
1218}
1219
1220//===----------------------------------------------------------------------===//
1221// ExtractAlignedPointerAsIndexOp
1222//===----------------------------------------------------------------------===//
1223
1224LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1225 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1226 ConversionPatternRewriter &rewriter) const {
1227 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1228 Type indexType = typeConverter.getIndexType();
1229 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1230 adaptor.getSource());
1231 return success();
1232}
1233
1234//===----------------------------------------------------------------------===//
1235// Pattern population
1236//===----------------------------------------------------------------------===//
1237
1238namespace mlir {
1240 RewritePatternSet &patterns) {
1241 patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1242 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1243 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1244 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1245 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
1246 patterns.getContext());
1247}
1248} // namespace mlir
return success()
static spirv::MemorySemantics getMemorySemanticsForStorageClass(spirv::StorageClass sc)
Returns the MemorySemantics storage-class bit corresponding to sc.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Type getElementTypeForStoragePointer(Type pointeeType, const SPIRVTypeConverter &typeConverter)
Extracts the element type from a SPIR-V pointer type pointing to storage.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations used for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
static spirv::MemorySemantics getAtomicAcqRelMemorySemantics(MemRefType type)
Returns the AcquireRelease memory semantics OR'd with the storage-class bit derived from the memory s...
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
iterator begin()
Definition Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess