//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include <cassert>
#include <limits>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array, treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` equals 8 and `targetBits` equals 32, the x-th element is
/// located at (x % 4) * 8, because there are four 8-bit elements in one i32.
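/// A sketch of the IR this emits, assuming `sourceBits` = 8 and
/// `targetBits` = 32 (value names hypothetical):
///   %c4  = spirv.Constant 4 : i32   // targetBits / sourceBits
///   %c8  = spirv.Constant 8 : i32   // sourceBits
///   %m   = spirv.UMod %srcIdx, %c4 : i32
///   %off = spirv.IMul %m, %c8 : i32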
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extract the required bits. When accessing a 1-D
/// array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
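/// A sketch, assuming `sourceBits` = 8 and `targetBits` = 32: the last access
/// chain index %i, which addresses an i8 element, is rewritten to %i / 4 so
/// the resulting pointer addresses the i32 word containing that element.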
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(),
                                      indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
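/// For example, casting %b : i1 to i32 emits, as a sketch (value names
/// hypothetical):
///   %zero = spirv.Constant 0 : i32
///   %one  = spirv.Constant 1 : i32
///   %r    = spirv.Select %b, %one, %zero : i1, i32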
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the value converted to the `targetBits`-bit destination type,
/// masked, and shifted left by the given `offset`.
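/// For example, placing an i8 value into the second byte of an i32 word uses
/// mask 0xFF and offset 8, producing (value & 0xFF) << 8.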
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = spirv::UConvertOp::create(
          builder, loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently we only support static shapes, with an element type that is an
  // int or float, or a complex or vector of int or float.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  if (auto compType = dyn_cast<ComplexType>(elementType))
    elementType = compType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  // A missing or non-SPIR-V memory space cannot be mapped to a scope.
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return std::nullopt;
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
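/// For example, as a sketch (the exact pointer type depends on the type
/// converter):
///   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
/// becomes
///   %0 = spirv.Variable : !spirv.ptr<!spirv.array<4 x f32>, Function>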
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
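/// For example, as a sketch (the variable name carries a running counter):
///   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
/// becomes a module-level
///   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
/// referenced via spirv.mlir.addressof at the original location.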
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Image + spirv.ImageFetch.
class ImageLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate SPIR-V cast
/// operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern final
    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName,
                                            /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                            \
  case arith::AtomicRMWKind::kind:                                            \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                              \
        atomicOp, resultType, ptr, *scope,                                    \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());          \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
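/// For example, an f32 access through a PhysicalStorageBuffer pointer with no
/// explicit alignment yields MemoryAccess::Aligned with a 4-byte alignment
/// derived from the scalar's size.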
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal,
                            uint64_t preferredAlignment) {
  if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) {
    return failure();
  }

  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  bool mayOmitAlignment =
      !preferredAlignment &&
      ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer;
  if (mayOmitAlignment) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute; other storage
  // classes may optionally carry it.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess |= spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes;
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal(),
                                     loadOrStoreOp.getAlignment().value_or(0));
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp,
        "failed to lower memref in image storage class to storage buffer");

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain,
                                          memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() works linearly. If it's a scalar, the method
  // still returns a linearized access. If the access is not linearized, there
  // will be offset issues.
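  // A sketch of the extraction sequence emitted below for an i8 load from an
  // i32 word (value names hypothetical):
  //   %word    = spirv.Load %adjustedPtr : i32
  //   %shifted = spirv.ShiftRightArithmetic %word, %offset : i32
  //   %masked  = spirv.BitwiseAnd %shifted, %mask : i32      // mask == 0xFF
  // followed by a ShiftLeftLogical/ShiftRightArithmetic pair that
  // sign-extends the extracted bits.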
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr,
                                          memoryAccess, alignment);

  // Shift the bits to the rightmost position.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() == spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp,
        "failed to lower memref in image storage class to storage buffer");

  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

template <typename OpAdaptor>
static FailureOr<SmallVector<Value>>
extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) {
  // At present we only support linear "tiling" as specified in Vulkan; this
  // means that texels are assumed to be laid out in memory in a row-major
  // order. This allows us to support any memref layout that is a permutation
  // of the dimensions. Future work will pass an optional image layout to the
  // rewrite pattern so that we can support optimized target-specific tilings.
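  // For example, with the identity layout on a 2-D memref and indices
  // [%i, %j] (%i slowest-moving), the coordinates composed below come out as
  // (%j, %i), since image coordinates are ordered fastest- to slowest-moving.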
  SmallVector<Value> indices = adaptor.getIndices();
  AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
  if (!map.isPermutation())
    return rewriter.notifyMatchFailure(
        loadOp,
        "cannot lower memrefs with a memory layout which is not a permutation");

  // The memref's layout determines the dimension ordering, so we need to
  // follow the map to get the ordering of the dimensions/indices.
  const unsigned dimCount = map.getNumDims();
  SmallVector<Value, 3> coords(dimCount);
  for (unsigned dim = 0; dim < dimCount; ++dim)
    coords[map.getDimPosition(dim)] = indices[dim];

  // We need to reverse the coordinates because the memref layout is slowest-
  // to fastest-moving while the vector coordinates for the image op are
  // fastest- to slowest-moving.
  return llvm::to_vector(llvm::reverse(coords));
}

LogicalResult
ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());

  auto memorySpaceAttr =
      dyn_cast_if_present<spirv::StorageClassAttr>(memrefType.getMemorySpace());
  if (!memorySpaceAttr)
    return rewriter.notifyMatchFailure(
        loadOp, "missing memory space SPIR-V storage class attribute");

  if (memorySpaceAttr.getValue() != spirv::StorageClass::Image)
    return rewriter.notifyMatchFailure(
        loadOp, "failed to lower memref in non-image storage class to image");

  Value loadPtr = adaptor.getMemref();
  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  const auto [memoryAccess, alignment] = *memoryRequirements;

  if (!loadOp.getMemRefType().hasRank())
    return rewriter.notifyMatchFailure(
        loadOp, "cannot lower unranked memrefs to SPIR-V images");

  // We currently only support lowering of scalar memref elements to texels in
  // the R[16|32][f|i|ui] formats. Future work will enable lowering of vector
  // elements to texels in richer formats.
  if (!isa<spirv::ScalarType>(loadOp.getMemRefType().getElementType()))
    return rewriter.notifyMatchFailure(
        loadOp,
        "cannot lower memrefs whose element type is not a SPIR-V scalar type "
        "to SPIR-V images");

  // We currently only support sampled images, since OpImageFetch does not work
  // for plain images; the OpImageRead instruction would need to be
  // materialized instead, or texels would need to be accessed via atomics
  // through a texel pointer. Future work will generalize support to plain
  // images.
  auto convertedPointeeType = cast<spirv::PointerType>(
      getTypeConverter()->convertType(loadOp.getMemRefType()));
  if (!isa<spirv::SampledImageType>(convertedPointeeType.getPointeeType()))
    return rewriter.notifyMatchFailure(loadOp,
                                       "cannot lower memrefs which do not "
                                       "convert to SPIR-V sampled images");

  // Materialize the lowering.
  Location loc = loadOp->getLoc();
  auto imageLoadOp =
      spirv::LoadOp::create(rewriter, loc, loadPtr, memoryAccess, alignment);
  // Extract the image from the sampled image.
  auto imageOp = spirv::ImageOp::create(rewriter, loc, imageLoadOp);

  // Build a vector of coordinates, or just a scalar index if we have a 1-D
  // image.
  Value coords;
  if (memrefType.getRank() == 1) {
    coords = adaptor.getIndices()[0];
  } else {
    FailureOr<SmallVector<Value>> maybeCoords =
        extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
    if (failed(maybeCoords))
      return failure();
    auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                           adaptor.getIndices().getType()[0]);
    coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
                                                 maybeCoords.value());
  }

  // Fetch the value out of the image.
  auto resultVectorType = VectorType::get({4}, loadOp.getType());
  auto fetchOp = spirv::ImageFetchOp::create(
      rewriter, loc, resultVectorType, imageOp, coords,
      mlir::spirv::ImageOperandsAttr{}, ValueRange{});

  // Note that because OpImageFetch returns a rank-4 vector, we need to extract
  // the element corresponding to the load; since we only support the
  // R[16|32][f|i|ui] formats, this is always the R (red) 0th vector element.
  auto compositeExtractOp =
      spirv::CompositeExtractOp::create(rewriter, loc, fetchOp, 0);

  rewriter.replaceOp(loadOp, compositeExtractOp);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be storing into adjacent lanes of the same
  // word, the emulation is done with atomic operations. E.g., if the stored
  // value is i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
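  // A sketch of the resulting IR, storing an i8 value %v into the second byte
  // of an i32 word (constants illustrative):
  //   %old = spirv.AtomicAnd "Device" "AcquireRelease" %ptr, %clearMask
  //   %new = spirv.AtomicOr "Device" "AcquireRelease" %ptr, %shifted
  // where %clearMask is 0xFFFF00FF and %shifted is (%v & 0xFF) << 8.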
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // an i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = spirv::AtomicAndOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, clearBitsMask);
  result = spirv::AtomicOrOp::create(
      rewriter, loc, dstType, adjustedPtr, *scope,
      spirv::MemorySemantics::AcquireRelease, storeVal);

  // The results of the atomic ops are unused. Since the replacement ops are
  // already inserted, we can just remove the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the numbers of
  // results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it
  // has conversions to and from generic pointers. To implement the general
  // case, we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert
  // the generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination has it.
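  // E.g., a CrossWorkgroup -> Function cast is emitted as, in sketch form:
  //   %generic = spirv.PtrCastToGeneric %src : ... to !spirv.ptr<..., Generic>
  //   %result = spirv.GenericCastToPtr %generic
  //               : ... to !spirv.ptr<..., Function>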
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType,
                                               result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isZeroInteger(offset)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, ValueRange());
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type indexType = typeConverter.getIndexType();
  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
                                                      adaptor.getSource());
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
               IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
               StoreOpPattern, ReinterpretCastPattern, CastPattern,
               ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
                                                      patterns.getContext());
}
} // namespace mlir