MLIR 22.0.0git
MemRefToSPIRV.cpp
Go to the documentation of this file.
1//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Visitors.h"
24#include <cassert>
25#include <limits>
26#include <optional>
27
28#define DEBUG_TYPE "memref-to-spirv-pattern"
29
30using namespace mlir;
31
32//===----------------------------------------------------------------------===//
33// Utility functions
34//===----------------------------------------------------------------------===//
35
36/// Returns the offset of the value in `targetBits` representation.
37///
38/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
39/// It's assumed to be non-negative.
40///
41/// When accessing an element in the array treating as having elements of
42/// `targetBits`, multiple values are loaded in the same time. The method
43/// returns the offset where the `srcIdx` locates in the value. For example, if
44/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
45/// located at (x % 4) * 8. Because there are four elements in one i32, and one
46/// element has 8 bits.
47static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
48 int targetBits, OpBuilder &builder) {
49 assert(targetBits % sourceBits == 0);
50 Type type = srcIdx.getType();
51 IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
52 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
53 IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
54 auto srcBitsValue =
55 builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
56 auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
57 return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
58}
59
60/// Returns an adjusted spirv::AccessChainOp. Based on the
61/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
62/// supported. During conversion if a memref of an unsupported type is used,
63/// load/stores to this memref need to be modified to use a supported higher
64/// bitwidth `targetBits` and extracting the required bits. For an accessing a
65/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
66/// the bits needed. The extraction of the actual bits needed are handled
67/// separately. Note that this only works for a 1-D tensor.
68static Value
70 spirv::AccessChainOp op, int sourceBits,
71 int targetBits, OpBuilder &builder) {
72 assert(targetBits % sourceBits == 0);
73 const auto loc = op.getLoc();
74 Value lastDim = op->getOperand(op.getNumOperands() - 1);
75 Type type = lastDim.getType();
76 IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
77 auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
78 auto indices = llvm::to_vector<4>(op.getIndices());
79 // There are two elements if this is a 1-D tensor.
80 assert(indices.size() == 2);
81 indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
82 Type t = typeConverter.convertType(op.getComponentPtr().getType());
83 return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
84 indices);
85}
86
87/// Casts the given `srcBool` into an integer of `dstType`.
88static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
89 OpBuilder &builder) {
90 assert(srcBool.getType().isInteger(1));
91 if (dstType.isInteger(1))
92 return srcBool;
93 Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
94 Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
95 return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
96 zero);
97}
98
99/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
100/// to the type destination type, and masked.
101static Value shiftValue(Location loc, Value value, Value offset, Value mask,
102 OpBuilder &builder) {
103 IntegerType dstType = cast<IntegerType>(mask.getType());
104 int targetBits = static_cast<int>(dstType.getWidth());
105 int valueBits = value.getType().getIntOrFloatBitWidth();
106 assert(valueBits <= targetBits);
107
108 if (valueBits == 1) {
109 value = castBoolToIntN(loc, value, dstType, builder);
110 } else {
111 if (valueBits < targetBits) {
112 value = spirv::UConvertOp::create(
113 builder, loc, builder.getIntegerType(targetBits), value);
114 }
115
116 value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
117 }
118 return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
119 value, offset);
120}
121
122/// Returns true if the allocations of memref `type` generated from `allocOp`
123/// can be lowered to SPIR-V.
124static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
125 if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
126 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
127 if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
128 return false;
129 } else if (isa<memref::AllocaOp>(allocOp)) {
130 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
131 if (!sc || sc.getValue() != spirv::StorageClass::Function)
132 return false;
133 } else {
134 return false;
135 }
136
137 // Currently only support static shape and int or float or vector of int or
138 // float element type.
139 if (!type.hasStaticShape())
140 return false;
141
142 Type elementType = type.getElementType();
143 if (auto vecType = dyn_cast<VectorType>(elementType))
144 elementType = vecType.getElementType();
145 return elementType.isIntOrFloat();
146}
147
148/// Returns the scope to use for atomic operations use for emulating store
149/// operations of unsupported integer bitwidths, based on the memref
150/// type. Returns std::nullopt on failure.
151static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
152 auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
153 switch (sc.getValue()) {
154 case spirv::StorageClass::StorageBuffer:
155 return spirv::Scope::Device;
156 case spirv::StorageClass::Workgroup:
157 return spirv::Scope::Workgroup;
158 default:
159 break;
160 }
161 return {};
162}
163
164/// Casts the given `srcInt` into a boolean value.
165static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
166 if (srcInt.getType().isInteger(1))
167 return srcInt;
168
169 auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
170 return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, one);
171}
172
173//===----------------------------------------------------------------------===//
174// Operation conversion
175//===----------------------------------------------------------------------===//
176
177// Note that DRR cannot be used for the patterns in this file: we may need to
178// convert type along the way, which requires ConversionPattern. DRR generates
179// normal RewritePattern.
180
181namespace {
182
183/// Converts memref.alloca to SPIR-V Function variables.
184class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
185public:
186 using Base::Base;
187
188 LogicalResult
189 matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
190 ConversionPatternRewriter &rewriter) const override;
191};
192
193/// Converts an allocation operation to SPIR-V. Currently only supports lowering
194/// to Workgroup memory when the size is constant. Note that this pattern needs
195/// to be applied in a pass that runs at least at spirv.module scope since it
196/// wil ladd global variables into the spirv.module.
197class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
198public:
199 using Base::Base;
200
201 LogicalResult
202 matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter) const override;
204};
205
206/// Converts memref.automic_rmw operations to SPIR-V atomic operations.
207class AtomicRMWOpPattern final
208 : public OpConversionPattern<memref::AtomicRMWOp> {
209public:
210 using Base::Base;
211
212 LogicalResult
213 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter) const override;
215};
216
217/// Removed a deallocation if it is a supported allocation. Currently only
218/// removes deallocation if the memory space is workgroup memory.
219class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
220public:
221 using Base::Base;
222
223 LogicalResult
224 matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
225 ConversionPatternRewriter &rewriter) const override;
226};
227
228/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
229class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
230public:
231 using Base::Base;
232
233 LogicalResult
234 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter) const override;
236};
237
238/// Converts memref.load to spirv.Load + spirv.AccessChain.
239class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
240public:
241 using Base::Base;
242
243 LogicalResult
244 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override;
246};
247
248/// Converts memref.load to spirv.Image + spirv.ImageFetch
249class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
250public:
251 using Base::Base;
252
253 LogicalResult
254 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
255 ConversionPatternRewriter &rewriter) const override;
256};
257
258/// Converts memref.store to spirv.Store on integers.
259class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
260public:
261 using Base::Base;
262
263 LogicalResult
264 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter) const override;
266};
267
268/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
269class MemorySpaceCastOpPattern final
270 : public OpConversionPattern<memref::MemorySpaceCastOp> {
271public:
272 using Base::Base;
273
274 LogicalResult
275 matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
276 ConversionPatternRewriter &rewriter) const override;
277};
278
279/// Converts memref.store to spirv.Store.
280class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
281public:
282 using Base::Base;
283
284 LogicalResult
285 matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter) const override;
287};
288
289class ReinterpretCastPattern final
290 : public OpConversionPattern<memref::ReinterpretCastOp> {
291public:
292 using Base::Base;
293
294 LogicalResult
295 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
296 ConversionPatternRewriter &rewriter) const override;
297};
298
299class CastPattern final : public OpConversionPattern<memref::CastOp> {
300public:
301 using Base::Base;
302
303 LogicalResult
304 matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
305 ConversionPatternRewriter &rewriter) const override {
306 Value src = adaptor.getSource();
307 Type srcType = src.getType();
308
309 const TypeConverter *converter = getTypeConverter();
310 Type dstType = converter->convertType(op.getType());
311 if (srcType != dstType)
312 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
313 diag << "types doesn't match: " << srcType << " and " << dstType;
314 });
315
316 rewriter.replaceOp(op, src);
317 return success();
318 }
319};
320
321/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
322class ExtractAlignedPointerAsIndexOpPattern final
323 : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
324public:
325 using Base::Base;
326
327 LogicalResult
328 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
329 OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter) const override;
331};
332} // namespace
333
334//===----------------------------------------------------------------------===//
335// AllocaOp
336//===----------------------------------------------------------------------===//
337
338LogicalResult
339AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const {
341 MemRefType allocType = allocaOp.getType();
342 if (!isAllocationSupported(allocaOp, allocType))
343 return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
344
345 // Get the SPIR-V type for the allocation.
346 Type spirvType = getTypeConverter()->convertType(allocType);
347 if (!spirvType)
348 return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
349
350 rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
351 spirv::StorageClass::Function,
352 /*initializer=*/nullptr);
353 return success();
354}
355
356//===----------------------------------------------------------------------===//
357// AllocOp
358//===----------------------------------------------------------------------===//
359
360LogicalResult
361AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter) const {
363 MemRefType allocType = operation.getType();
364 if (!isAllocationSupported(operation, allocType))
365 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
366
367 // Get the SPIR-V type for the allocation.
368 Type spirvType = getTypeConverter()->convertType(allocType);
369 if (!spirvType)
370 return rewriter.notifyMatchFailure(operation, "type conversion failed");
371
372 // Insert spirv.GlobalVariable for this allocation.
373 Operation *parent =
374 SymbolTable::getNearestSymbolTable(operation->getParentOp());
375 if (!parent)
376 return failure();
377 Location loc = operation.getLoc();
378 spirv::GlobalVariableOp varOp;
379 {
380 OpBuilder::InsertionGuard guard(rewriter);
381 Block &entryBlock = *parent->getRegion(0).begin();
382 rewriter.setInsertionPointToStart(&entryBlock);
383 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
384 std::string varName =
385 std::string("__workgroup_mem__") +
386 std::to_string(std::distance(varOps.begin(), varOps.end()));
387 varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
388 /*initializer=*/nullptr);
389 }
390
391 // Get pointer to global variable at the current scope.
392 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
393 return success();
394}
395
396//===----------------------------------------------------------------------===//
397// AtomicRMWOp
398//===----------------------------------------------------------------------===//
399
400LogicalResult
401AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
402 OpAdaptor adaptor,
403 ConversionPatternRewriter &rewriter) const {
404 if (isa<FloatType>(atomicOp.getType()))
405 return rewriter.notifyMatchFailure(atomicOp,
406 "unimplemented floating-point case");
407
408 auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
409 std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
410 if (!scope)
411 return rewriter.notifyMatchFailure(atomicOp,
412 "unsupported memref memory space");
413
414 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
415 Type resultType = typeConverter.convertType(atomicOp.getType());
416 if (!resultType)
417 return rewriter.notifyMatchFailure(atomicOp,
418 "failed to convert result type");
419
420 auto loc = atomicOp.getLoc();
421 Value ptr =
422 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
423 adaptor.getIndices(), loc, rewriter);
424
425 if (!ptr)
426 return failure();
427
428#define ATOMIC_CASE(kind, spirvOp) \
429 case arith::AtomicRMWKind::kind: \
430 rewriter.replaceOpWithNewOp<spirv::spirvOp>( \
431 atomicOp, resultType, ptr, *scope, \
432 spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \
433 break
434
435 switch (atomicOp.getKind()) {
436 ATOMIC_CASE(addi, AtomicIAddOp);
437 ATOMIC_CASE(maxs, AtomicSMaxOp);
438 ATOMIC_CASE(maxu, AtomicUMaxOp);
439 ATOMIC_CASE(mins, AtomicSMinOp);
440 ATOMIC_CASE(minu, AtomicUMinOp);
441 ATOMIC_CASE(ori, AtomicOrOp);
442 ATOMIC_CASE(andi, AtomicAndOp);
443 default:
444 return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
445 }
446
447#undef ATOMIC_CASE
448
449 return success();
450}
451
452//===----------------------------------------------------------------------===//
453// DeallocOp
454//===----------------------------------------------------------------------===//
455
456LogicalResult
457DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
458 OpAdaptor adaptor,
459 ConversionPatternRewriter &rewriter) const {
460 MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
461 if (!isAllocationSupported(operation, deallocType))
462 return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
463 rewriter.eraseOp(operation);
464 return success();
465}
466
467//===----------------------------------------------------------------------===//
468// LoadOp
469//===----------------------------------------------------------------------===//
470
472 spirv::MemoryAccessAttr memoryAccess;
473 IntegerAttr alignment;
474};
475
476/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
477/// any.
478static FailureOr<MemoryRequirements>
479calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
480 uint64_t preferredAlignment) {
481 if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
482 return failure();
483 }
484
485 MLIRContext *ctx = accessedPtr.getContext();
486
487 auto memoryAccess = spirv::MemoryAccess::None;
488 if (isNontemporal) {
489 memoryAccess = spirv::MemoryAccess::Nontemporal;
490 }
491
492 auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
493 bool mayOmitAlignment =
494 !preferredAlignment &&
495 ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
496 if (mayOmitAlignment) {
497 if (memoryAccess == spirv::MemoryAccess::None) {
498 return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
499 }
500 return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
501 IntegerAttr{}};
502 }
503
504 // PhysicalStorageBuffers require the `Aligned` attribute.
505 // Other storage types may show an `Aligned` attribute.
506 auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
507 if (!pointeeType)
508 return failure();
509
510 // For scalar types, the alignment is determined by their size.
511 std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
512 if (!sizeInBytes.has_value())
513 return failure();
514
515 memoryAccess |= spirv::MemoryAccess::Aligned;
516 auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
517 auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
518 auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
519 return MemoryRequirements{memAccessAttr, alignment};
520}
521
522/// Given an accessed SPIR-V pointer and the original memref load/store
523/// `memAccess` op, calculates the alignment requirements, if any. Takes into
524/// account the alignment attributes applied to the load/store op.
525template <class LoadOrStoreOp>
526static FailureOr<MemoryRequirements>
527calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
528 static_assert(
529 llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
530 "Must be called on either memref::LoadOp or memref::StoreOp");
531
532 return calculateMemoryRequirements(accessedPtr,
533 loadOrStoreOp.getNontemporal(),
534 loadOrStoreOp.getAlignment().value_or(0));
535}
536
537LogicalResult
538IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
539 ConversionPatternRewriter &rewriter) const {
540 auto loc = loadOp.getLoc();
541 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
542 if (!memrefType.getElementType().isSignlessInteger())
543 return failure();
544
545 auto memorySpaceAttr =
546 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
547 if (!memorySpaceAttr)
548 return rewriter.notifyMatchFailure(
549 loadOp, "missing memory space SPIR-V storage class attribute");
550
551 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
552 return rewriter.notifyMatchFailure(
553 loadOp,
554 "failed to lower memref in image storage class to storage buffer");
555
556 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
557 Value accessChain =
558 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
559 adaptor.getIndices(), loc, rewriter);
560
561 if (!accessChain)
562 return failure();
563
564 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
565 bool isBool = srcBits == 1;
566 if (isBool)
567 srcBits = typeConverter.getOptions().boolNumBits;
568
569 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
570 if (!pointerType)
571 return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");
572
573 Type pointeeType = pointerType.getPointeeType();
574 Type dstType;
575 if (typeConverter.allows(spirv::Capability::Kernel)) {
576 if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
577 dstType = arrayType.getElementType();
578 else
579 dstType = pointeeType;
580 } else {
581 // For Vulkan we need to extract element from wrapping struct and array.
582 Type structElemType =
583 cast<spirv::StructType>(pointeeType).getElementType(0);
584 if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
585 dstType = arrayType.getElementType();
586 else
587 dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
588 }
589 int dstBits = dstType.getIntOrFloatBitWidth();
590 assert(dstBits % srcBits == 0);
591
592 // If the rewritten load op has the same bit width, use the loading value
593 // directly.
594 if (srcBits == dstBits) {
595 auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
596 if (failed(memoryRequirements))
597 return rewriter.notifyMatchFailure(
598 loadOp, "failed to determine memory requirements");
599
600 auto [memoryAccess, alignment] = *memoryRequirements;
601 Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
602 memoryAccess, alignment);
603 if (isBool)
604 loadVal = castIntNToBool(loc, loadVal, rewriter);
605 rewriter.replaceOp(loadOp, loadVal);
606 return success();
607 }
608
609 // Bitcasting is currently unsupported for Kernel capability /
610 // spirv.PtrAccessChain.
611 if (typeConverter.allows(spirv::Capability::Kernel))
612 return failure();
613
614 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
615 if (!accessChainOp)
616 return failure();
617
618 // Assume that getElementPtr() works linearizely. If it's a scalar, the method
619 // still returns a linearized accessing. If the accessing is not linearized,
620 // there will be offset issues.
621 assert(accessChainOp.getIndices().size() == 2);
622 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
623 srcBits, dstBits, rewriter);
624 auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
625 if (failed(memoryRequirements))
626 return rewriter.notifyMatchFailure(
627 loadOp, "failed to determine memory requirements");
628
629 auto [memoryAccess, alignment] = *memoryRequirements;
630 Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
631 memoryAccess, alignment);
632
633 // Shift the bits to the rightmost.
634 // ____XXXX________ -> ____________XXXX
635 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
636 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
637 Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
638 loc, spvLoadOp.getType(), spvLoadOp, offset);
639
640 // Apply the mask to extract corresponding bits.
641 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
642 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
643 result =
644 rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
645
646 // Apply sign extension on the loading value unconditionally. The signedness
647 // semantic is carried in the operator itself, we relies other pattern to
648 // handle the casting.
649 IntegerAttr shiftValueAttr =
650 rewriter.getIntegerAttr(dstType, dstBits - srcBits);
651 Value shiftValue =
652 rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
653 result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
655 result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
656 loc, dstType, result, shiftValue);
657
658 rewriter.replaceOp(loadOp, result);
659
660 assert(accessChainOp.use_empty());
661 rewriter.eraseOp(accessChainOp);
662
663 return success();
664}
665
666LogicalResult
667LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter) const {
669 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
670 if (memrefType.getElementType().isSignlessInteger())
671 return failure();
672
673 auto memorySpaceAttr =
674 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
675 if (!memorySpaceAttr)
676 return rewriter.notifyMatchFailure(
677 loadOp, "missing memory space SPIR-V storage class attribute");
678
679 if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
680 return rewriter.notifyMatchFailure(
681 loadOp,
682 "failed to lower memref in image storage class to storage buffer");
683
684 Value loadPtr = spirv::getElementPtr(
685 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
686 adaptor.getIndices(), loadOp.getLoc(), rewriter);
687
688 if (!loadPtr)
689 return failure();
690
691 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
692 if (failed(memoryRequirements))
693 return rewriter.notifyMatchFailure(
694 loadOp, "failed to determine memory requirements");
695
696 auto [memoryAccess, alignment] = *memoryRequirements;
697 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
698 alignment);
699 return success();
700}
701
702template <typename OpAdaptor>
703static FailureOr<SmallVector<Value>>
704extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
705 ConversionPatternRewriter &rewriter) {
706 // At present we only support linear "tiling" as specified in Vulkan, this
707 // means that texels are assumed to be laid out in memory in a row-major
708 // order. This allows us to support any memref layout that is a permutation of
709 // the dimensions. Future work will pass an optional image layout to the
710 // rewrite pattern so that we can support optimized target specific tilings.
711 SmallVector<Value> indices = adaptor.getIndices();
712 AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
713 if (!map.isPermutation())
714 return rewriter.notifyMatchFailure(
715 loadOp,
716 "Cannot lower memrefs with memory layout which is not a permutation");
717
718 // The memrefs layout determines the dimension ordering so we need to follow
719 // the map to get the ordering of the dimensions/indices.
720 const unsigned dimCount = map.getNumDims();
721 SmallVector<Value, 3> coords(dimCount);
722 for (unsigned dim = 0; dim < dimCount; ++dim)
723 coords[map.getDimPosition(dim)] = indices[dim];
724
725 // We need to reverse the coordinates because the memref layout is slowest to
726 // fastest moving and the vector coordinates for the image op is fastest to
727 // slowest moving.
728 return llvm::to_vector(llvm::reverse(coords));
729}
730
731LogicalResult
732ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
733 ConversionPatternRewriter &rewriter) const {
734 auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
735
736 auto memorySpaceAttr =
737 dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
738 if (!memorySpaceAttr)
739 return rewriter.notifyMatchFailure(
740 loadOp, "missing memory space SPIR-V storage class attribute");
741
742 if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
743 return rewriter.notifyMatchFailure(
744 loadOp, "failed to lower memref in non-image storage class to image");
745
746 Value loadPtr = adaptor.getMemref();
747 auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
748 if (failed(memoryRequirements))
749 return rewriter.notifyMatchFailure(
750 loadOp, "failed to determine memory requirements");
751
752 const auto [memoryAccess, alignment] = *memoryRequirements;
753
754 if (!loadOp.getMemRefType().hasRank())
755 return rewriter.notifyMatchFailure(
756 loadOp, "cannot lower unranked memrefs to SPIR-V images");
757
758 // We currently only support lowering of scalar memref elements to texels in
759 // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
760 // elements to texels in richer formats.
761 if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
762 return rewriter.notifyMatchFailure(
763 loadOp,
764 "cannot lower memrefs who's element type is not a SPIR-V scalar type"
765 "to SPIR-V images");
766
767 // We currently only support sampled images since OpImageFetch does not work
768 // for plain images and the OpImageRead instruction needs to be materialized
769 // instead or texels need to be accessed via atomics through a texel pointer.
770 // Future work will generalize support to plain images.
771 auto convertedPointeeType = cast<spirv::PointerType>(
772 getTypeConverter()->convertType(loadOp.getMemRefType()));
773 if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
774 return rewriter.notifyMatchFailure(loadOp,
775 "cannot lower memrefs which do not "
776 "convert to SPIR-V sampled images");
777
778 // Materialize the lowering.
779 Location loc = loadOp->getLoc();
780 auto imageLoadOp =
781 spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
782 // Extract the image from the sampled image.
783 auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);
784
785 // Build a vector of coordinates or just a scalar index if we have a 1D image.
786 Value coords;
787 if (memrefType.getRank() == 1) {
788 coords = adaptor.getIndices()[0];
789 } else {
790 FailureOr<SmallVector<Value>> maybeCoords =
791 extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
792 if (failed(maybeCoords))
793 return failure();
794 auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
795 adaptor.getIndices().getType()[0]);
796 coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
797 maybeCoords.value());
798 }
799
800 // Fetch the value out of the image.
801 auto resultVectorType = VectorType::get({4}, loadOp.getType());
802 auto fetchOp = spirv::ImageFetchOp::create(
803 rewriter, loc, resultVectorType, imageOp, coords,
804 mlir::spirv::ImageOperandsAttr{}, ValueRange{});
805
806 // Note that because OpImageFetch returns a rank 4 vector we need to extract
807 // the elements corresponding to the load which will since we only support the
808 // R[16|32][f|i|ui] formats will always be the R(red) 0th vector element.
809 auto compositeExtractOp =
810 spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);
811
812 rewriter.replaceOp(loadOp, compositeExtractOp);
813 return success();
814}
815
816LogicalResult
817IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
818 ConversionPatternRewriter &rewriter) const {
819 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
820 if (!memrefType.getElementType().isSignlessInteger())
821 return rewriter.notifyMatchFailure(storeOp,
822 "element type is not a signless int");
823
824 auto loc = storeOp.getLoc();
825 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
826 Value accessChain =
827 spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
828 adaptor.getIndices(), loc, rewriter);
829
830 if (!accessChain)
831 return rewriter.notifyMatchFailure(
832 storeOp, "failed to convert element pointer type");
833
834 int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
835
836 bool isBool = srcBits == 1;
837 if (isBool)
838 srcBits = typeConverter.getOptions().boolNumBits;
839
840 auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
841 if (!pointerType)
842 return rewriter.notifyMatchFailure(storeOp,
843 "failed to convert memref type");
844
845 Type pointeeType = pointerType.getPointeeType();
846 IntegerType dstType;
847 if (typeConverter.allows(spirv::Capability::Kernel)) {
848 if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
849 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
850 else
851 dstType = dyn_cast<IntegerType>(pointeeType);
852 } else {
853 // For Vulkan we need to extract element from wrapping struct and array.
854 Type structElemType =
855 cast<spirv::StructType>(pointeeType).getElementType(0);
856 if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
857 dstType = dyn_cast<IntegerType>(arrayType.getElementType());
858 else
859 dstType = dyn_cast<IntegerType>(
860 cast<spirv::RuntimeArrayType>(structElemType).getElementType());
861 }
862
863 if (!dstType)
864 return rewriter.notifyMatchFailure(
865 storeOp, "failed to determine destination element type");
866
867 int dstBits = static_cast<int>(dstType.getWidth());
868 assert(dstBits % srcBits == 0);
869
870 if (srcBits == dstBits) {
871 auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
872 if (failed(memoryRequirements))
873 return rewriter.notifyMatchFailure(
874 storeOp, "failed to determine memory requirements");
875
876 auto [memoryAccess, alignment] = *memoryRequirements;
877 Value storeVal = adaptor.getValue();
878 if (isBool)
879 storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
880 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
881 memoryAccess, alignment);
882 return success();
883 }
884
885 // Bitcasting is currently unsupported for Kernel capability /
886 // spirv.PtrAccessChain.
887 if (typeConverter.allows(spirv::Capability::Kernel))
888 return failure();
889
890 auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
891 if (!accessChainOp)
892 return failure();
893
894 // Since there are multiple threads in the processing, the emulation will be
895 // done with atomic operations. E.g., if the stored value is i8, rewrite the
896 // StoreOp to:
897 // 1) load a 32-bit integer
898 // 2) clear 8 bits in the loaded value
899 // 3) set 8 bits in the loaded value
900 // 4) store 32-bit value back
901 //
902 // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
903 // loaded 32-bit value and the shifted 8-bit store value) as another atomic
904 // step.
905 assert(accessChainOp.getIndices().size() == 2);
906 Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
907 Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
908
909 // Create a mask to clear the destination. E.g., if it is the second i8 in
910 // i32, 0xFFFF00FF is created.
911 Value mask = rewriter.createOrFold<spirv::ConstantOp>(
912 loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
913 Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
914 loc, dstType, mask, offset);
915 clearBitsMask =
916 rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
917
918 Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
919 Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
920 srcBits, dstBits, rewriter);
921 std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
922 if (!scope)
923 return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
924
925 Value result = spirv::AtomicAndOp::create(
926 rewriter, loc, dstType, adjustedPtr, *scope,
927 spirv::MemorySemantics::AcquireRelease, clearBitsMask);
928 result = spirv::AtomicOrOp::create(
929 rewriter, loc, dstType, adjustedPtr, *scope,
930 spirv::MemorySemantics::AcquireRelease, storeVal);
931
932 // The AtomicOrOp has no side effect. Since it is already inserted, we can
933 // just remove the original StoreOp. Note that rewriter.replaceOp()
934 // doesn't work because it only accepts that the numbers of result are the
935 // same.
936 rewriter.eraseOp(storeOp);
937
938 assert(accessChainOp.use_empty());
939 rewriter.eraseOp(accessChainOp);
940
941 return success();
942}
943
944//===----------------------------------------------------------------------===//
945// MemorySpaceCastOp
946//===----------------------------------------------------------------------===//
947
948LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
949 memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
950 ConversionPatternRewriter &rewriter) const {
951 Location loc = addrCastOp.getLoc();
952 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
953 if (!typeConverter.allows(spirv::Capability::Kernel))
954 return rewriter.notifyMatchFailure(
955 loc, "address space casts require kernel capability");
956
957 auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
958 if (!sourceType)
959 return rewriter.notifyMatchFailure(
960 loc, "SPIR-V lowering requires ranked memref types");
961 auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
962
963 auto sourceStorageClassAttr =
964 dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
965 if (!sourceStorageClassAttr)
966 return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
967 diag << "source address space " << sourceType.getMemorySpace()
968 << " must be a SPIR-V storage class";
969 });
970 auto resultStorageClassAttr =
971 dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
972 if (!resultStorageClassAttr)
973 return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
974 diag << "result address space " << resultType.getMemorySpace()
975 << " must be a SPIR-V storage class";
976 });
977
978 spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
979 spirv::StorageClass resultSc = resultStorageClassAttr.getValue();
980
981 Value result = adaptor.getSource();
982 Type resultPtrType = typeConverter.convertType(resultType);
983 if (!resultPtrType)
984 return rewriter.notifyMatchFailure(addrCastOp,
985 "failed to convert memref type");
986
987 Type genericPtrType = resultPtrType;
988 // SPIR-V doesn't have a general address space cast operation. Instead, it has
989 // conversions to and from generic pointers. To implement the general case,
990 // we use specific-to-generic conversions when the source class is not
991 // generic. Then when the result storage class is not generic, we convert the
992 // generic pointer (either the input on ar intermediate result) to that
993 // class. This also means that we'll need the intermediate generic pointer
994 // type if neither the source or destination have it.
995 if (sourceSc != spirv::StorageClass::Generic &&
996 resultSc != spirv::StorageClass::Generic) {
997 Type intermediateType =
998 MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
999 sourceType.getLayout(),
1000 rewriter.getAttr<spirv::StorageClassAttr>(
1001 spirv::StorageClass::Generic));
1002 genericPtrType = typeConverter.convertType(intermediateType);
1003 }
1004 if (sourceSc != spirv::StorageClass::Generic) {
1005 result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
1006 result);
1007 }
1008 if (resultSc != spirv::StorageClass::Generic) {
1009 result =
1010 spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
1011 }
1012 rewriter.replaceOp(addrCastOp, result);
1013 return success();
1014}
1015
1016LogicalResult
1017StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
1018 ConversionPatternRewriter &rewriter) const {
1019 auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
1020 if (memrefType.getElementType().isSignlessInteger())
1021 return rewriter.notifyMatchFailure(storeOp, "signless int");
1022 auto storePtr = spirv::getElementPtr(
1023 *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
1024 adaptor.getIndices(), storeOp.getLoc(), rewriter);
1025
1026 if (!storePtr)
1027 return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
1028
1029 auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
1030 if (failed(memoryRequirements))
1031 return rewriter.notifyMatchFailure(
1032 storeOp, "failed to determine memory requirements");
1033
1034 auto [memoryAccess, alignment] = *memoryRequirements;
1035 rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1036 storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
1037 return success();
1038}
1039
1040LogicalResult ReinterpretCastPattern::matchAndRewrite(
1041 memref::ReinterpretCastOp op, OpAdaptor adaptor,
1042 ConversionPatternRewriter &rewriter) const {
1043 Value src = adaptor.getSource();
1044 auto srcType = dyn_cast<spirv::PointerType>(src.getType());
1045
1046 if (!srcType)
1047 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1048 diag << "invalid src type " << src.getType();
1049 });
1050
1051 const TypeConverter *converter = getTypeConverter();
1052
1053 auto dstType = converter->convertType<spirv::PointerType>(op.getType());
1054 if (dstType != srcType)
1055 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1056 diag << "invalid dst type " << op.getType();
1057 });
1058
1059 OpFoldResult offset =
1060 getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
1061 .front();
1062 if (isZeroInteger(offset)) {
1063 rewriter.replaceOp(op, src);
1064 return success();
1065 }
1066
1067 Type intType = converter->convertType(rewriter.getIndexType());
1068 if (!intType)
1069 return rewriter.notifyMatchFailure(op, "failed to convert index type");
1070
1071 Location loc = op.getLoc();
1072 auto offsetValue = [&]() -> Value {
1073 if (auto val = dyn_cast<Value>(offset))
1074 return val;
1075
1076 int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
1077 Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
1078 return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
1079 }();
1080
1081 rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
1082 op, src, offsetValue, ValueRange());
1083 return success();
1084}
1085
1086//===----------------------------------------------------------------------===//
1087// ExtractAlignedPointerAsIndexOp
1088//===----------------------------------------------------------------------===//
1089
1090LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
1091 memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
1092 ConversionPatternRewriter &rewriter) const {
1093 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1094 Type indexType = typeConverter.getIndexType();
1095 rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
1096 adaptor.getSource());
1097 return success();
1098}
1099
1100//===----------------------------------------------------------------------===//
1101// Pattern population
1102//===----------------------------------------------------------------------===//
1103
1104namespace mlir {
1107 patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
1108 DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
1109 IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
1110 StoreOpPattern, ReinterpretCastPattern, CastPattern,
1111 ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
1112 patterns.getContext());
1113}
1114} // namespace mlir
return success()
static Type getElementType(Type type)
Determine the element type of type.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder)
Casts the given srcInt into a boolean value.
static std::optional< spirv::Scope > getAtomicOpScope(MemRefType type)
Returns the scope to use for atomic operations used for emulating store operations of unsupported inte...
static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder)
Returns the targetBits-bit value shifted by the given offset, and cast to the type destination type,...
static FailureOr< SmallVector< Value > > extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
static Value adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder)
Returns an adjusted spirv::AccessChainOp.
static bool isAllocationSupported(Operation *allocOp, MemRefType type)
Returns true if the allocations of memref type generated from allocOp can be lowered to SPIR-V.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder)
Returns the offset of the value in targetBits representation.
#define ATOMIC_CASE(kind, spirvOp)
static FailureOr< MemoryRequirements > calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment)
Given an accessed SPIR-V pointer, calculates its alignment requirements, if any.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder)
Casts the given srcBool into an integer of dstType.
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:193
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
iterator begin()
Definition Region.h:55
Type conversion from builtin types to SPIR-V types for shader interface.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating MemRef ops to SPIR-V ops.
spirv::MemoryAccessAttr memoryAccess