MLIR 23.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Visitors.h"
24#include <cassert>
25#include <limits>
26#include <optional>
27
28#define DEBUG_TYPE "memref-to-spirv-pattern"
29
30using namespace mlir;
31
32//===----------------------------------------------------------------------===//
33// Utility functions
34//===----------------------------------------------------------------------===//
35
36/// Returns the offset of the value in `targetBits` representation.
37///
38/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
39/// It's assumed to be non-negative.
40///
41/// When accessing an element in the array treating as having elements of
42/// `targetBits`, multiple values are loaded in the same time. The method
43/// returns the offset where the `srcIdx` locates in the value. For example, if
44/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
45/// located at (x % 4) * 8. Because there are four elements in one i32, and one
46/// element has 8 bits.
47static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
48 int targetBits, OpBuilder &builder) {
49 assert(targetBits % sourceBits == 0);
50 Type type = srcIdx.getType();
51 IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
54 auto srcBitsValue =
55 builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
58}
59
60/// Returns an adjusted spirv::AccessChainOp. Based on the
61/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
62/// supported. During conversion if a memref of an unsupported type is used,
63/// load/stores to this memref need to be modified to use a supported higher
64/// bitwidth `targetBits` and extracting the required bits. For an accessing a
65/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
66/// the bits needed. The extraction of the actual bits needed are handled
67/// separately. Note that this only works for a 1-D tensor.
68static Value
70 spirv::AccessChainOp op, int sourceBits,
71 int targetBits, OpBuilder &builder) {
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
75 Type type = lastDim.getType();
76 IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
79 // There are two elements if this is a 1-D tensor.
80 assert(indices.size() == 2);
81 indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
84 indices);
85}
86
87/// Casts the given `srcBool` into an integer of `dstType`.
88static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
89 OpBuilder &builder) {
90 assert(srcBool.getType().isInteger(1));
91 if (dstType.isInteger(1))
92 return srcBool;
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
96 zero);
97}
98
99/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
100/// to the type destination type, and masked.
101static Value shiftValue(Location loc, Value value, Value offset, Value mask,
102 OpBuilder &builder) {
103 IntegerType dstType = cast<IntegerType>(mask.getType());
104 int targetBits = static_cast<int>(dstType.getWidth());
105 int valueBits = value.getType().getIntOrFloatBitWidth();
106 assert(valueBits <= targetBits);
107
108 if (valueBits == 1) {
109 value = castBoolToIntN(loc, value, dstType, builder);
110 } else {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
113 builder, loc, builder.getIntegerType(targetBits), value);
114 }
115
116 value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
117 }
118 return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
119 value, offset);
120}
121
122/// Returns true if the allocations of memref `type` generated from `allocOp`
123/// can be lowered to SPIR-V.
124static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
128 return false;
129 } else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
132 return false;
133 } else {
134 return false;
135 }
136
137 // Currently only support static shape and int or float, complex of int or
138 // float, or vector of int or float element type.
139 if (!type.hasStaticShape())
140 return false;
141
142 Type elementType = type.getElementType();
143 if (auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 if (auto compType = dyn_cast<ComplexType>(elementType))
146 elementType = compType.getElementType();
147 return elementType.isIntOrFloat();
148}
149
150/// Returns the scope to use for atomic operations use for emulating store
151/// operations of unsupported integer bitwidths, based on the memref
152/// type. Returns std::nullopt on failure.
153static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
154 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
155 switch (sc.getValue()) {
156 case spirv::StorageClass::StorageBuffer:
157 return spirv::Scope::Device;
158 case spirv::StorageClass::Workgroup:
159 return spirv::Scope::Workgroup;
160 default:
161 break;
162 }
163 return {};
164}
165
166/// Returns the MemorySemantics storage-class bit corresponding to `sc`.
167/// Per SPIR-V spec section 3.32 (Memory Semantics) this bit must be OR'd
168/// with the ordering bits (Acquire/Release/...) on atomic operations.
169static spirv::MemorySemantics
170getMemorySemanticsForStorageClass(spirv::StorageClass sc) {
171 switch (sc) {
172 case spirv::StorageClass::StorageBuffer:
173 case spirv::StorageClass::Uniform:
174 return spirv::MemorySemantics::UniformMemory;
175 case spirv::StorageClass::Workgroup:
176 return spirv::MemorySemantics::WorkgroupMemory;
177 case spirv::StorageClass::CrossWorkgroup:
178 return spirv::MemorySemantics::CrossWorkgroupMemory;
179 case spirv::StorageClass::AtomicCounter:
180 return spirv::MemorySemantics::AtomicCounterMemory;
181 case spirv::StorageClass::Image:
182 return spirv::MemorySemantics::ImageMemory;
183 default:
184 return spirv::MemorySemantics::None;
185 }
186}
187
188/// Returns the AcquireRelease memory semantics OR'd with the storage-class
189/// bit derived from the memory space of `type`.
190static spirv::MemorySemantics getAtomicAcqRelMemorySemantics(MemRefType type) {
191 auto sc = cast<spirv::StorageClassAttr>(type.getMemorySpace()).getValue();
192 return spirv::MemorySemantics::AcquireRelease |
194}
195
196/// Extracts the element type from a SPIR-V pointer type pointing to storage.
197///
198/// For Kernel capability, the pointer points directly to the element type
199/// (possibly wrapped in an array). For Vulkan, the pointer points to a struct
200/// containing an array or runtime array, and we need to unwrap to get the
201/// element type.
202static Type
204 const SPIRVTypeConverter &typeConverter) {
205 if (typeConverter.allows(spirv::Capability::Kernel)) {
206 if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
207 return arrayType.getElementType();
208 return pointeeType;
209 }
210 // For Vulkan we need to extract element from wrapping struct and array.
211 Type structElemType = cast<spirv::StructType>(pointeeType).getElementType(0);
212 if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
213 return arrayType.getElementType();
214 return cast<spirv::RuntimeArrayType>(structElemType).getElementType();
215}
216
217/// Casts the given `srcInt` into a boolean value.
218static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
219 if (srcInt.getType().isInteger(1))
220 return srcInt;
221
222 auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
223 return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
224}
225
226//===----------------------------------------------------------------------===//
227// Operation conversion
228//===----------------------------------------------------------------------===//
229
230// Note that DRR cannot be used for the patterns in this file: we may need to
231// convert type along the way, which requires ConversionPattern. DRR generates
232// normal RewritePattern.
233
namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Image + spirv.ImageFetch
class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.reinterpret_cast ops (implementation defined elsewhere in
/// this file).
class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Erases memref.cast when the converted source value already has the
/// converted target type; fails to match otherwise.
class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types doesn't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
386
387//===----------------------------------------------------------------------===//
388// AllocaOp
389//===----------------------------------------------------------------------===//
390
391LogicalResult
392AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
393 ConversionPatternRewriter &rewriter) const {
394 MemRefType allocType = allocaOp.getType();
395 if (!isAllocationSupported(allocaOp, allocType))
396 return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
397
398 // Get the SPIR-V type for the allocation.
399 Type spirvType = getTypeConverter()->convertType(allocType);
400 if (!spirvType)
401 return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
402
403 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
404 spirv::StorageClass::Function,
405 /*initializer=*/nullptr);
406 return success();
407}
408
409//===----------------------------------------------------------------------===//
410// AllocOp
411//===----------------------------------------------------------------------===//
412
413LogicalResult
414AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter) const {
416 MemRefType allocType = operation.getType();
417 if (!isAllocationSupported(operation, allocType))
418 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
419
420 // Get the SPIR-V type for the allocation.
421 Type spirvType = getTypeConverter()->convertType(allocType);
422 if (!spirvType)
423 return rewriter.notifyMatchFailure(operation, "type conversion failed");
424
425 // Insert spirv.GlobalVariable for this allocation.
426 Operation *parent =
427 SymbolTable::getNearestSymbolTable(operation->getParentOp());
428 if (!parent)
429 return failure();
430 Location loc = operation.getLoc();
431 spirv::GlobalVariableOp varOp;
432 {
433 OpBuilder::InsertionGuard guard(rewriter);
434 Block &entryBlock = *parent->getRegion(0).begin();
435 rewriter.setInsertionPointToStart(&entryBlock);
436 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
437 std::string varName =
438 std::string("__workgroup_mem__") +
439 std::to_string(std::distance(varOps.begin(), varOps.end()));
440 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
441 /*initializer=*/nullptr);
442 }
443
444 // Get pointer to global variable at the current scope.
445 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
446 return success();
447}
448
449//===----------------------------------------------------------------------===//
// AtomicRMWOp
451//===----------------------------------------------------------------------===//
452
LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  // Floating-point atomics are not handled by this pattern.
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  // Build the access chain addressing the element being atomically updated.
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

  // Determine the source and destination bitwidths. The source is the original
  // memref element type and the destination is the SPIR-V storage type (e.g.,
  // i32 for Vulkan).
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  auto dstType = dyn_cast<IntegerType>(
      getElementTypeForStoragePointer(pointeeType, typeConverter));
  if (!dstType)
    return rewriter.notifyMatchFailure(
        atomicOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  // AcquireRelease ordering OR'd with the storage-class memory bit.
  spirv::MemorySemantics memSem = getAtomicAcqRelMemorySemantics(memrefType);

  // When the source and destination bitwidths match, emit the atomic operation
  // directly.
  if (srcBits == dstBits) {
#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope, memSem, adaptor.getValue());        \
    break

    switch (atomicOp.getKind()) {
      ATOMIC_CASE(addi, AtomicIAddOp);
      ATOMIC_CASE(maxs, AtomicSMaxOp);
      ATOMIC_CASE(maxu, AtomicUMaxOp);
      ATOMIC_CASE(mins, AtomicSMinOp);
      ATOMIC_CASE(minu, AtomicUMinOp);
      ATOMIC_CASE(ori, AtomicOrOp);
      ATOMIC_CASE(andi, AtomicAndOp);
    default:
      return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
    }

#undef ATOMIC_CASE

    return success();
  }

  // Sub-element-width atomic: the element type (e.g., i8) is narrower than the
  // storage type (e.g., i32). We need to adjust the index and shift/mask the
  // value to operate on the correct bits within the wider storage element.
  //
  // Only ori and andi can be emulated because they operate bitwise and don't
  // carry across byte boundaries. Other kinds (addi, max, min) would require
  // CAS loops.
  if (atomicOp.getKind() != arith::AtomicRMWKind::ori &&
      atomicOp.getKind() != arith::AtomicRMWKind::andi) {
    return rewriter.notifyMatchFailure(
        atomicOp,
        "atomic op on sub-element-width types is only supported for ori/andi");
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        atomicOp,
        "sub-element-width atomic ops unsupported with Kernel capability");

  auto accessChainOp = ptr.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Compute the bit offset within the storage element and adjust the pointer
  // to address the containing storage element.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value result;
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::ori: {
    // OR only sets bits, so shifting the value to the target position and
    // ORing with zeros in other positions preserves the unaffected bits.
    Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
        loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
    Value storeVal =
        shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
    result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
                                       *scope, memSem, storeVal);
    break;
  }
  case arith::AtomicRMWKind::andi: {
    // Build a mask that preserves all bits outside the target element
    // and applies the operand mask to the target element.
    // mask = (operand << offset) | ~(elemMask << offset)
    Value elemMask = rewriter.createOrFold<spirv::ConstantOp>(
        loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
    Value storeVal =
        shiftValue(loc, adaptor.getValue(), offset, elemMask, rewriter);
    Value shiftedElemMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
        loc, dstType, elemMask, offset);
    Value invertedElemMask =
        rewriter.createOrFold<spirv::NotOp>(loc, dstType, shiftedElemMask);
    Value mask = rewriter.createOrFold<spirv::BitwiseOrOp>(loc, storeVal,
                                                           invertedElemMask);
    result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
                                        *scope, memSem, mask);
    break;
  }
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

  // The atomic op returns the old value of the full storage element (e.g.,
  // i32). Extract the original sub-element value from the correct position.
  result = rewriter.createOrFold<spirv::ShiftRightLogicalOp>(loc, dstType,
                                                             result, offset);
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1uLL << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
  rewriter.replaceOp(atomicOp, result);

  return success();
}
607
608//===----------------------------------------------------------------------===//
609// DeallocOp
610//===----------------------------------------------------------------------===//
611
612LogicalResult
613DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
614 OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter) const {
616 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
617 if (!isAllocationSupported(operation, deallocType))
618 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
619 rewriter.eraseOp(operation);
620 return success();
621}
622
623//===----------------------------------------------------------------------===//
624// LoadOp
625//===----------------------------------------------------------------------===//
626
628 spirv::MemoryAccessAttr memoryAccess;
629 IntegerAttr alignment;
630};
631
632/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
633/// any.
634static FailureOr<MemoryRequirements>
635calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
636 uint64_t preferredAlignment) {
637 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
638 return failure();
639 }
640
641 MLIRContext *ctx = accessedPtr.getContext();
642
643 auto memoryAccess = spirv::MemoryAccess::None;
644 if (isNontemporal) {
645 memoryAccess = spirv::MemoryAccess::Nontemporal;
646 }
647
648 auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
649 bool mayOmitAlignment =
650 !preferredAlignment &&
651 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
652 if (mayOmitAlignment) {
653 if (memoryAccess == spirv::MemoryAccess::None) {
654 return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
655 }
656 return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
657 IntegerAttr{}};
658 }
659
660 // PhysicalStorageBuffers require the `Aligned` attribute.
661 // Other storage types may show an `Aligned` attribute.
662 std::optional<int64_t> sizeInBytes;
663 Type rawPointeeType = ptrType.getPointeeType();
664 if (auto scalarType = dyn_cast<spirv::ScalarType>(rawPointeeType)) {
665 // For scalar types, the alignment is determined by their size.
666 sizeInBytes = scalarType.getSizeInBytes();
667 } else if (auto vecType = dyn_cast<VectorType>(rawPointeeType)) {
668 // For vector element types, the alignment should equal the total size of
669 // the vector.
670 if (auto scalarElem =
671 dyn_cast<spirv::ScalarType>(vecType.getElementType())) {
672 if (auto elemSize = scalarElem.getSizeInBytes())
673 sizeInBytes = *elemSize * vecType.getNumElements();
674 }
675 }
676
677 if (!sizeInBytes.has_value())
678 return failure();
679
680 memoryAccess |= spirv::MemoryAccess::Aligned;
681 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
682 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
683 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
684 return MemoryRequirements{memAccessAttr, alignment};
685}
686
687/// Given an accessed SPIR-V pointer and the original memref load/store
688/// `memAccess` op, calculates the alignment requirements, if any. Takes into
689/// account the alignment attributes applied to the load/store op.
690template <class LoadOrStoreOp>
691static FailureOr<MemoryRequirements>
692calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
693 static_assert(
694 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
695 "Must be called on either memref::LoadOp or memref::StoreOp");
696
697 return calculateMemoryRequirements(accessedPtr,
698 loadOrStoreOp.getNontemporal(),
699 loadOrStoreOp.getAlignment().value_or(0));
700}
701
702LogicalResult
703IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
704 ConversionPatternRewriter &rewriter) const {
705 auto loc = loadOp.getLoc();
706 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
707 if (!memrefType.getElementType().isSignlessInteger())
708 return failure();
709
710 auto memorySpaceAttr =
711 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
712 if (!memorySpaceAttr)
713 return rewriter.notifyMatchFailure(
714 loadOp, "missing memory space SPIR-V storage class attribute");
715
716 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
717 return rewriter.notifyMatchFailure(
718 loadOp,
719 "failed to lower memref in image storage class to storage buffer");
720
721 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
722 Value accessChain =
723 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
724 adaptor.getIndices(), loc, rewriter);
725
726 if (!accessChain)
727 return failure();
728
729 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
730 bool isBool = srcBits == 1;
731 if (isBool)
732 srcBits = typeConverter.getOptions().boolNumBits;
733
734 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
735 if (!pointerType)
736 return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
737
738 Type pointeeType = pointerType.getPointeeType();
739 Type dstType = getElementTypeForStoragePointer(pointeeType, typeConverter);
740 int dstBits = dstType.getIntOrFloatBitWidth();
741 assert(dstBits % srcBits == 0);
742
743 // If the rewritten load op has the same bit width, use the loading value
744 // directly.
745 if (srcBits == dstBits) {
746 auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
747 if (failed(memoryRequirements))
748 return rewriter.notifyMatchFailure(
749 loadOp, "failed to determine memory requirements");
750
751 auto [memoryAccess, alignment] = *memoryRequirements;
752 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
753 memoryAccess, alignment);
754 if (isBool)
755 loadVal = castIntNToBool(loc, loadVal, rewriter);
756 rewriter.replaceOp(loadOp, loadVal);
757 return success();
758 }
759
760 // Bitcasting is currently unsupported for Kernel capability /
761 // spirv.PtrAccessChain.
762 if (typeConverter.allows(spirv::Capability::Kernel))
763 return failure();
764
765 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
766 if (!accessChainOp)
767 return failure();
768
769 // Assume that getElementPtr() works linearizely. If it's a scalar, the method
770 // still returns a linearized accessing. If the accessing is not linearized,
771 // there will be offset issues.
772 assert(accessChainOp.getIndices().size() == 2);
773 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
774 srcBits, dstBits, rewriter);
775 auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
776 if (failed(memoryRequirements))
777 return rewriter.notifyMatchFailure(
778 loadOp, "failed to determine memory requirements");
779
780 auto [memoryAccess, alignment] = *memoryRequirements;
781 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
782 memoryAccess, alignment);
783
784 // Shift the bits to the rightmost.
785 // ____XXXX________ -> ____________XXXX
786 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
787 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
788 Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
789 loc, spvLoadOp.getType(), spvLoadOp, offset);
790
791 // Apply the mask to extract corresponding bits.
792 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
793 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
794 result =
795 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
796
797 // Apply sign extension on the loading value unconditionally. The signedness
798 // semantic is carried in the operator itself, we relies other pattern to
799 // handle the casting.
800 IntegerAttr shiftValueAttr =
801 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
802 Value shiftValue =
803 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
804 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
806 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
807 loc, dstType, result, shiftValue);
808
809 rewriter.replaceOp(loadOp, result);
810
811 assert(accessChainOp.use_empty());
812 rewriter.eraseOp(accessChainOp);
813
814 return success();
815}
816
817LogicalResult
818LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
819 ConversionPatternRewriter &rewriter) const {
820 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
821 if (memrefType.getElementType().isSignlessInteger())
822 return failure();
823
824 auto memorySpaceAttr =
825 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
826 if (!memorySpaceAttr)
827 return rewriter.notifyMatchFailure(
828 loadOp, "missing memory space SPIR-V storage class attribute");
829
830 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
831 return rewriter.notifyMatchFailure(
832 loadOp,
833 "failed to lower memref in image storage class to storage buffer");
834
835 Value loadPtr = spirv::getElementPtr(
836 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
837 adaptor.getIndices(), loadOp.getLoc(), rewriter);
838
839 if (!loadPtr)
840 return failure();
841
842 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
843 if (failed(memoryRequirements))
844 return rewriter.notifyMatchFailure(
845 loadOp, "failed to determine memory requirements");
846
847 auto [memoryAccess, alignment] = *memoryRequirements;
848 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
849 alignment);
850 return success();
851}
852
853template <typename OpAdaptor>
854static FailureOr<SmallVector<Value>>
855extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
856 ConversionPatternRewriter &rewriter) {
857 // At present we only support linear "tiling" as specified in Vulkan, this
858 // means that texels are assumed to be laid out in memory in a row-major
859 // order. This allows us to support any memref layout that is a permutation of
860 // the dimensions. Future work will pass an optional image layout to the
861 // rewrite pattern so that we can support optimized target specific tilings.
862 SmallVector<Value> indices = adaptor.getIndices();
863 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
864 if (!map.isPermutation())
865 return rewriter.notifyMatchFailure(
866 loadOp,
867 "Cannot lower memrefs with memory layout which is not a permutation");
868
869 // The memrefs layout determines the dimension ordering so we need to follow
870 // the map to get the ordering of the dimensions/indices.
871 const unsigned dimCount = map.getNumDims();
872 SmallVector<Value, 3> coords(dimCount);
873 for (unsigned dim = 0; dim < dimCount; ++dim)
874 coords[map.getDimPosition(dim)] = indices[dim];
875
876 // We need to reverse the coordinates because the memref layout is slowest to
877 // fastest moving and the vector coordinates for the image op is fastest to
878 // slowest moving.
879 return llvm::to_vector(llvm::reverse(coords));
880}
881
882LogicalResult
883ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
884 ConversionPatternRewriter &rewriter) const {
885 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
886
887 auto memorySpaceAttr =
888 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
889 if (!memorySpaceAttr)
890 return rewriter.notifyMatchFailure(
891 loadOp, "missing memory space SPIR-V storage class attribute");
892
893 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
894 return rewriter.notifyMatchFailure(
895 loadOp, "failed to lower memref in non-image storage class to image");
896
897 Value loadPtr = adaptor.getMemref();
898 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
899 if (failed(memoryRequirements))
900 return rewriter.notifyMatchFailure(
901 loadOp, "failed to determine memory requirements");
902
903 const auto [memoryAccess, alignment] = *memoryRequirements;
904
905 if (!loadOp.getMemRefType().hasRank())
906 return rewriter.notifyMatchFailure(
907 loadOp, "cannot lower unranked memrefs to SPIR-V images");
908
909 // We currently only support lowering of scalar memref elements to texels in
910 // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
911 // elements to texels in richer formats.
912 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
913 return rewriter.notifyMatchFailure(
914 loadOp,
915 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
916 "to SPIR-V images");
917
918 // We currently only support sampled images since OpImageFetch does not work
919 // for plain images and the OpImageRead instruction needs to be materialized
920 // instead or texels need to be accessed via atomics through a texel pointer.
921 // Future work will generalize support to plain images.
922 auto convertedPointeeType = cast<spirv::PointerType>(
923 getTypeConverter()->convertType(loadOp.getMemRefType()));
924 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
925 return rewriter.notifyMatchFailure(loadOp,
926 "cannot lower memrefs which do not "
927 "convert to SPIR-V sampled images");
928
929 // Materialize the lowering.
930 Location loc = loadOp->getLoc();
931 auto imageLoadOp =
932 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
933 // Extract the image from the sampled image.
934 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
935
936 // Build a vector of coordinates or just a scalar index if we have a 1D image.
937 Value coords;
938 if (memrefType.getRank() == 1) {
939 coords = adaptor.getIndices()[0];
940 } else {
941 FailureOr<SmallVector<Value>> maybeCoords =
942 extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
943 if (failed(maybeCoords))
944 return failure();
945 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
946 adaptor.getIndices().getType()[0]);
947 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
948 maybeCoords.value());
949 }
950
951 // Fetch the value out of the image.
952 auto resultVectorType = VectorType::get({4}, loadOp.getType());
953 auto fetchOp = spirv::ImageFetchOp::create(
954 rewriter, loc, resultVectorType, imageOp, coords,
955 mlir::spirv::ImageOperandsAttr{}, ValueRange{});
956
957 // Note that because OpImageFetch returns a rank 4 vector we need to extract
958 // the elements corresponding to the load which will since we only support the
959 // R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
960 auto compositeExtractOp =
961 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
962
963 rewriter.replaceOp(loadOp, compositeExtractOp);
964 return success();
965}
966
/// Lowers memref.store of signless-integer elements, emulating bitwidths the
/// target cannot store directly (e.g. i8 within an i32 word, or i1 widened to
/// `boolNumBits`) via atomic read-modify-write of the containing word.
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // Non-signless element types are handled by StoreOpPattern instead.
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  // Pointer to the (converted) element being stored to.
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  // i1 is stored as an integer of `boolNumBits` (per the conversion options).
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  // Width of the integer actually stored in memory after conversion.
  Type pointeeType = pointerType.getPointeeType();
  auto dstType = dyn_cast<IntegerType>(
      getElementTypeForStoragePointer(pointeeType, typeConverter));
  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  // Fast path: the element width matches the storage width, so a plain
  // spirv.Store suffices (after widening a bool value if needed).
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since there are multiple threads in the processing, the emulation will be
  // done with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
  assert(accessChainOp.getIndices().size() == 2);
  // Bit offset of the element within its containing dst-width word.
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  // Widen/mask the stored value and shift it into position, then retarget the
  // access chain to address the containing dst-width word.
  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  // AtomicAnd clears the target bits; AtomicOr merges in the shifted value.
  spirv::MemorySemantics memSem = getAtomicAcqRelMemorySemantics(memrefType);
  Value result = spirv::AtomicAndOp::create(rewriter, loc, dstType, adjustedPtr,
                                            *scope, memSem, clearBitsMask);
  result = spirv::AtomicOrOp::create(rewriter, loc, dstType, adjustedPtr,
                                     *scope, memSem, storeVal);

  // The AtomicOrOp has no side effect. Since it is already inserted, we can
  // just remove the original StoreOp. Note that rewriter.replaceOp()
  // doesn't work because it only accepts that the numbers of result are the
  // same.
  rewriter.eraseOp(storeOp);

  // The adjusted access chain replaced all uses of the original one, so the
  // original chain must now be dead and can be erased.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
1078
1079//===----------------------------------------------------------------------===//
1080// MemorySpaceCastOp
1081//===----------------------------------------------------------------------===//
1082
1083LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
1084 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
1085 ConversionPatternRewriter &rewriter) const {
1086 Location loc = addrCastOp.getLoc();
1087 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1088 if (!typeConverter.allows(spirv::Capability::Kernel))
1089 return rewriter.notifyMatchFailure(
1090 loc, "address space casts require kernel capability");
1091
1092 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
1093 if (!sourceType)
1094 return rewriter.notifyMatchFailure(
1095 loc, "SPIR-V lowering requires ranked memref types");
1096 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
1097
1098 auto sourceStorageClassAttr =
1099 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
1100 if (!sourceStorageClassAttr)
1101 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
1102 diag << "source address space " << sourceType.getMemorySpace()
1103 << " must be a SPIR-V storage class";
1104 });
1105 auto resultStorageClassAttr =
1106 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
1107 if (!resultStorageClassAttr)
1108 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
1109 diag << "result address space " << resultType.getMemorySpace()
1110 << " must be a SPIR-V storage class";
1111 });
1112
1113 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
1114 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
1115
1116 Value result = adaptor.getSource();
1117 Type resultPtrType = typeConverter.convertType(resultType);
1118 if (!resultPtrType)
1119 return rewriter.notifyMatchFailure(addrCastOp,
1120 "failed to convert memref type");
1121
1122 Type genericPtrType = resultPtrType;
1123 // SPIR-V doesn't have a general address space cast operation. Instead, it has
1124 // conversions to and from generic pointers. To implement the general case,
1125 // we use specific-to-generic conversions when the source class is not
1126 // generic. Then when the result storage class is not generic, we convert the
1127 // generic pointer (either the input on ar intermediate result) to that
1128 // class. This also means that we'll need the intermediate generic pointer
1129 // type if neither the source or destination have it.
1130 if (sourceSc != spirv::StorageClass::Generic &&
1131 resultSc != spirv::StorageClass::Generic) {
1132 Type intermediateType =
1133 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
1134 sourceType.getLayout(),
1135 rewriter.getAttr<spirv::StorageClassAttr>(
1136 spirv::StorageClass::Generic));
1137 genericPtrType = typeConverter.convertType(intermediateType);
1138 }
1139 if (sourceSc != spirv::StorageClass::Generic) {
1140 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1141 result);
1142 }
1143 if (resultSc != spirv::StorageClass::Generic) {
1144 result =
1145 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
1146 }
1147 rewriter.replaceOp(addrCastOp, result);
1148 return success();
1149}
1150
1151LogicalResult
1152StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1153 ConversionPatternRewriter &rewriter) const {
1154 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1155 if (memrefType.getElementType().isSignlessInteger())
1156 return rewriter.notifyMatchFailure(storeOp, "signless int");
1157 auto storePtr = spirv::getElementPtr(
1158 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1159 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1160
1161 if (!storePtr)
1162 return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
1163
1164 auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
1165 if (failed(memoryRequirements))
1166 return rewriter.notifyMatchFailure(
1167 storeOp, "failed to determine memory requirements");
1168
1169 auto [memoryAccess, alignment] = *memoryRequirements;
1170 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1171 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1172 return success();
1173}
1174
1175LogicalResult ReinterpretCastPattern::matchAndRewrite(
1176 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1177 ConversionPatternRewriter &rewriter) const {
1178 Value src = adaptor.getSource();
1179 auto srcType = dyn_cast<spirv::PointerType>(src.getType());
1180
1181 if (!srcType)
1182 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1183 diag << "invalid src type " << src.getType();
1184 });
1185
1186 const TypeConverter *converter = getTypeConverter();
1187
1188 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1189 if (dstType != srcType)
1190 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1191 diag << "invalid dst type " << op.getType();
1192 });
1193
1194 OpFoldResult offset =
1195 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1196 .front();
1197 if (isZeroInteger(offset)) {
1198 rewriter.replaceOp(op, src);
1199 return success();
1200 }
1201
1202 Type intType = converter->convertType(rewriter.getIndexType());
1203 if (!intType)
1204 return rewriter.notifyMatchFailure(op, "failed to convert index type");
1205
1206 Location loc = op.getLoc();
1207 auto offsetValue = [&]() -> Value {
1208 if (auto val = dyn_cast<Value>(offset))
1209 return val;
1210
1211 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1212 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1213 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1214 }();
1215
1216 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1217 op, src, offsetValue, ValueRange());
1218 return success();
1219}
1220
1221//===----------------------------------------------------------------------===//
1222// ExtractAlignedPointerAsIndexOp
1223//===----------------------------------------------------------------------===//
1224
1225LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1226 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1227 ConversionPatternRewriter &rewriter) const {
1228 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1229 Type indexType = typeConverter.getIndexType();
1230 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1231 adaptor.getSource());
1232 return success();
1233}
1234
1235//===----------------------------------------------------------------------===//
1236// Pattern population
1237//===----------------------------------------------------------------------===//
1238
1239namespace mlir {
1241 RewritePatternSet &patterns) {
1242 patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1243 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1244 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1245 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1246 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
1247 patterns.getContext());
1248}
1249} // namespace mlir
return success()
static spirv::MemorySemantics getMemorySemanticsForStorageClass(spirv::StorageClass sc)
Returns the MemorySemantics storage-class bit corresponding to sc.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static Type getElementTypeForStoragePointer(Type pointeeType, const SPIRVTypeConverter &typeConverter)
Extracts the element type from a SPIR-V pointer type pointing to storage.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations use for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
static spirv::MemorySemantics getAtomicAcqRelMemorySemantics(MemRefType type)
Returns the AcquireRelease memory semantics OR'd with the storage-class bit derived from the memory s...
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
iterator begin()
Definition Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess