MLIR 22.0.0git
UnifyAliasedResourcePass.cpp
Go to the documentation of this file.
1//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass that unifies access of multiple aliased resources
10// into access of one single resource.
11//
12//===----------------------------------------------------------------------===//
13
15
20#include "mlir/IR/Builders.h"
23#include "mlir/IR/SymbolTable.h"
25#include "llvm/ADT/DenseMap.h"
26#include "llvm/ADT/STLExtras.h"
27#include <iterator>
28
29namespace mlir {
30namespace spirv {
31#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
32#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
33} // namespace spirv
34} // namespace mlir
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Utility functions
40//===----------------------------------------------------------------------===//
41
42using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
45
46/// Collects all aliased resources in the given SPIR-V `moduleOp`.
47static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
48 AliasedResourceMap aliasedResources;
49 moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
50 if (varOp->getAttrOfType<UnitAttr>("aliased")) {
51 std::optional<uint32_t> set = varOp.getDescriptorSet();
52 std::optional<uint32_t> binding = varOp.getBinding();
53 if (set && binding)
54 aliasedResources[{*set, *binding}].push_back(varOp);
55 }
56 });
57 return aliasedResources;
58}
59
60/// Returns the element type if the given `type` is a runtime array resource:
61/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
62/// otherwise.
64 auto ptrType = dyn_cast<spirv::PointerType>(type);
65 if (!ptrType)
66 return {};
67
68 auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
69 if (!structType || structType.getNumElements() != 1)
70 return {};
71
72 auto rtArrayType =
73 dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
74 if (!rtArrayType)
75 return {};
76
77 return rtArrayType.getElementType();
78}
79
80/// Given a list of resource element `types`, returns the index of the canonical
81/// resource that all resources should be unified into. Returns std::nullopt if
82/// unable to unify.
83static std::optional<int>
85 // scalarNumBits: contains all resources' scalar types' bit counts.
86 // vectorNumBits: only contains resources whose element types are vectors.
87 // vectorIndices: each vector's original index in `types`.
88 SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
89 scalarNumBits.reserve(types.size());
90 vectorNumBits.reserve(types.size());
91 vectorIndices.reserve(types.size());
92
93 for (const auto &indexedTypes : llvm::enumerate(types)) {
94 spirv::SPIRVType type = indexedTypes.value();
95 assert(type.isScalarOrVector());
96 if (auto vectorType = dyn_cast<VectorType>(type)) {
97 if (vectorType.getNumElements() % 2 != 0)
98 return std::nullopt; // Odd-sized vector has special layout
99 // requirements.
100
101 std::optional<int64_t> numBytes = type.getSizeInBytes();
102 if (!numBytes)
103 return std::nullopt;
104
105 scalarNumBits.push_back(
106 vectorType.getElementType().getIntOrFloatBitWidth());
107 vectorNumBits.push_back(*numBytes * 8);
108 vectorIndices.push_back(indexedTypes.index());
109 } else {
110 scalarNumBits.push_back(type.getIntOrFloatBitWidth());
111 }
112 }
113
114 if (!vectorNumBits.empty()) {
115 // Choose the *vector* with the smallest bitwidth as the canonical resource,
116 // so that we can still keep vectorized load/store and avoid partial updates
117 // to large vectors.
118 auto *minVal = llvm::min_element(vectorNumBits);
119 // Make sure that the canonical resource's bitwidth is divisible by others.
120 // With out this, we cannot properly adjust the index later.
121 if (llvm::any_of(vectorNumBits,
122 [&](int bits) { return bits % *minVal != 0; }))
123 return std::nullopt;
124
125 // Require all scalar type bit counts to be a multiple of the chosen
126 // vector's primitive type to avoid reading/writing subcomponents.
127 int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
128 int baseNumBits = scalarNumBits[index];
129 if (llvm::any_of(scalarNumBits,
130 [&](int bits) { return bits % baseNumBits != 0; }))
131 return std::nullopt;
132
133 return index;
134 }
135
136 // All element types are scalars. Then choose the smallest bitwidth as the
137 // cannonical resource to avoid subcomponent load/store.
138 auto *minVal = llvm::min_element(scalarNumBits);
139 if (llvm::any_of(scalarNumBits,
140 [minVal](int64_t bit) { return bit % *minVal != 0; }))
141 return std::nullopt;
142 return std::distance(scalarNumBits.begin(), minVal);
143}
144
146 return a.isIntOrFloat() && b.isIntOrFloat() &&
147 a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
148}
149
150//===----------------------------------------------------------------------===//
151// Analysis
152//===----------------------------------------------------------------------===//
153
154namespace {
155/// A class for analyzing aliased resources.
156///
157/// Resources are expected to be spirv.GlobalVarible that has a descriptor set
158/// and binding number. Such resources are of the type
159/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
160///
161/// Right now, we only support the case that there is a single runtime array
162/// inside the struct.
163class ResourceAliasAnalysis {
164public:
166
167 explicit ResourceAliasAnalysis(Operation *);
168
169 /// Returns true if the given `op` can be rewritten to use a canonical
170 /// resource.
171 bool shouldUnify(Operation *op) const;
172
173 /// Returns all descriptors and their corresponding aliased resources.
174 const AliasedResourceMap &getResourceMap() const { return resourceMap; }
175
176 /// Returns the canonical resource for the given descriptor/variable.
177 spirv::GlobalVariableOp
178 getCanonicalResource(const Descriptor &descriptor) const;
179 spirv::GlobalVariableOp
180 getCanonicalResource(spirv::GlobalVariableOp varOp) const;
181
182 /// Returns the element type for the given variable.
183 spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
184
185private:
186 /// Given the descriptor and aliased resources bound to it, analyze whether we
187 /// can unify them and record if so.
188 void recordIfUnifiable(const Descriptor &descriptor,
189 ArrayRef<spirv::GlobalVariableOp> resources);
190
191 /// Mapping from a descriptor to all aliased resources bound to it.
192 AliasedResourceMap resourceMap;
193
194 /// Mapping from a descriptor to the chosen canonical resource.
196
197 /// Mapping from an aliased resource to its descriptor.
199
200 /// Mapping from an aliased resource to its element (scalar/vector) type.
202};
203} // namespace
204
205ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
206 // Collect all aliased resources first and put them into different sets
207 // according to the descriptor.
208 AliasedResourceMap aliasedResources =
209 collectAliasedResources(cast<spirv::ModuleOp>(root));
210
211 // For each resource set, analyze whether we can unify; if so, try to identify
212 // a canonical resource, whose element type has the largest bitwidth.
213 for (const auto &descriptorResource : aliasedResources) {
214 recordIfUnifiable(descriptorResource.first, descriptorResource.second);
215 }
216}
217
218bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
219 if (!op)
220 return false;
221
222 if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
223 auto canonicalOp = getCanonicalResource(varOp);
224 return canonicalOp && varOp != canonicalOp;
225 }
226 if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
227 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
228 auto *varOp =
229 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
230 return shouldUnify(varOp);
231 }
232
233 if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
234 return shouldUnify(acOp.getBasePtr().getDefiningOp());
235 if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
236 return shouldUnify(loadOp.getPtr().getDefiningOp());
237 if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
238 return shouldUnify(storeOp.getPtr().getDefiningOp());
239
240 return false;
241}
242
243spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
244 const Descriptor &descriptor) const {
245 auto varIt = canonicalResourceMap.find(descriptor);
246 if (varIt == canonicalResourceMap.end())
247 return {};
248 return varIt->second;
249}
250
251spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
252 spirv::GlobalVariableOp varOp) const {
253 auto descriptorIt = descriptorMap.find(varOp);
254 if (descriptorIt == descriptorMap.end())
255 return {};
256 return getCanonicalResource(descriptorIt->second);
257}
258
259spirv::SPIRVType
260ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
261 auto it = elementTypeMap.find(varOp);
262 if (it == elementTypeMap.end())
263 return {};
264 return it->second;
265}
266
267void ResourceAliasAnalysis::recordIfUnifiable(
268 const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
269 // Collect the element types for all resources in the current set.
270 SmallVector<spirv::SPIRVType> elementTypes;
271 for (spirv::GlobalVariableOp resource : resources) {
272 Type elementType = getRuntimeArrayElementType(resource.getType());
273 if (!elementType)
274 return; // Unexpected resource variable type.
275
276 auto type = cast<spirv::SPIRVType>(elementType);
277 if (!type.isScalarOrVector())
278 return; // Unexpected resource element type.
279
280 elementTypes.push_back(type);
281 }
282
283 std::optional<int> index = deduceCanonicalResource(elementTypes);
284 if (!index)
285 return;
286
287 // Update internal data structures for later use.
288 resourceMap[descriptor].assign(resources.begin(), resources.end());
289 canonicalResourceMap[descriptor] = resources[*index];
290 for (const auto &resource : llvm::enumerate(resources)) {
291 descriptorMap[resource.value()] = descriptor;
292 elementTypeMap[resource.value()] = elementTypes[resource.index()];
293 }
294}
295
296//===----------------------------------------------------------------------===//
297// Patterns
298//===----------------------------------------------------------------------===//
299
300template <typename OpTy>
301class ConvertAliasResource : public OpConversionPattern<OpTy> {
302public:
303 ConvertAliasResource(const ResourceAliasAnalysis &analysis,
304 MLIRContext *context, PatternBenefit benefit = 1)
305 : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
306
307protected:
308 const ResourceAliasAnalysis &analysis;
309};
310
311struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
313
314 LogicalResult
315 matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter) const override {
317 // Just remove the aliased resource. Users will be rewritten to use the
318 // canonical one.
319 rewriter.eraseOp(varOp);
320 return success();
321 }
322};
323
324struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
326
327 LogicalResult
328 matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter) const override {
330 // Rewrite the AddressOf op to get the address of the canoncical resource.
331 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
332 auto srcVarOp = cast<spirv::GlobalVariableOp>(
333 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
334 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
335 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
336 return success();
337 }
338};
339
340struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
342
343 LogicalResult
344 matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
345 ConversionPatternRewriter &rewriter) const override {
346 auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
347 if (!addressOp)
348 return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
349
350 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
351 auto srcVarOp = cast<spirv::GlobalVariableOp>(
352 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
353 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
354
355 spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
356 spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
357
358 if (srcElemType == dstElemType ||
359 areSameBitwidthScalarType(srcElemType, dstElemType)) {
360 // We have the same bitwidth for source and destination element types.
361 // Thie indices keep the same.
362 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
363 acOp, adaptor.getBasePtr(), adaptor.getIndices());
364 return success();
365 }
366
367 Location loc = acOp.getLoc();
368
369 if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
370 // The source indices are for a buffer with scalar element types. Rewrite
371 // them into a buffer with vector element types. We need to scale the last
372 // index for the vector as a whole, then add one level of index for inside
373 // the vector.
374 int srcNumBytes = *srcElemType.getSizeInBytes();
375 int dstNumBytes = *dstElemType.getSizeInBytes();
376 assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
377
378 auto indices = llvm::to_vector<4>(acOp.getIndices());
379 Value oldIndex = indices.back();
380 Type indexType = oldIndex.getType();
381
382 int ratio = dstNumBytes / srcNumBytes;
383 auto ratioValue = spirv::ConstantOp::create(
384 rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
385
386 indices.back() =
387 spirv::SDivOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
388 indices.push_back(spirv::SModOp::create(rewriter, loc, indexType,
389 oldIndex, ratioValue));
390
391 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
392 acOp, adaptor.getBasePtr(), indices);
393 return success();
394 }
395
396 if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
397 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
398 // The source indices are for a buffer with larger bitwidth scalar/vector
399 // element types. Rewrite them into a buffer with smaller bitwidth element
400 // types. We only need to scale the last index.
401 int srcNumBytes = *srcElemType.getSizeInBytes();
402 int dstNumBytes = *dstElemType.getSizeInBytes();
403 assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
405 auto indices = llvm::to_vector<4>(acOp.getIndices());
406 Value oldIndex = indices.back();
407 Type indexType = oldIndex.getType();
409 int ratio = srcNumBytes / dstNumBytes;
410 auto ratioValue = spirv::ConstantOp::create(
411 rewriter, loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
412
413 indices.back() =
414 spirv::IMulOp::create(rewriter, loc, indexType, oldIndex, ratioValue);
415
416 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
417 acOp, adaptor.getBasePtr(), indices);
418 return success();
420
421 return rewriter.notifyMatchFailure(
422 acOp, "unsupported src/dst types for spirv.AccessChain");
423 }
424};
426struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
429 LogicalResult
430 matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
431 ConversionPatternRewriter &rewriter) const override {
432 auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
433 auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
434 auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
435 auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
436
437 Location loc = loadOp.getLoc();
438 auto newLoadOp = spirv::LoadOp::create(rewriter, loc, adaptor.getPtr());
439 if (srcElemType == dstElemType) {
440 rewriter.replaceOp(loadOp, newLoadOp->getResults());
441 return success();
442 }
443
444 if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
445 auto castOp = spirv::BitcastOp::create(rewriter, loc, srcElemType,
446 newLoadOp.getValue());
447 rewriter.replaceOp(loadOp, castOp->getResults());
448
449 return success();
450 }
451
452 if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
453 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
454 // The source and destination have scalar types of different bitwidths, or
455 // vector types of different component counts. For such cases, we load
456 // multiple smaller bitwidth values and construct a larger bitwidth one.
457
458 int srcNumBytes = *srcElemType.getSizeInBytes();
459 int dstNumBytes = *dstElemType.getSizeInBytes();
460 assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
461 int ratio = srcNumBytes / dstNumBytes;
462 if (ratio > 4)
463 return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
464
465 SmallVector<Value> components;
466 components.reserve(ratio);
467 components.push_back(newLoadOp);
468
469 auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
470 if (!acOp)
471 return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");
472
473 auto i32Type = rewriter.getI32Type();
474 Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
475 auto indices = llvm::to_vector<4>(acOp.getIndices());
476 for (int i = 1; i < ratio; ++i) {
477 // Load all subsequent components belonging to this element.
478 indices.back() = spirv::IAddOp::create(rewriter, loc, i32Type,
479 indices.back(), oneValue);
480 auto componentAcOp = spirv::AccessChainOp::create(
481 rewriter, loc, acOp.getBasePtr(), indices);
482 // Assuming little endian, this reads lower-ordered bits of the number
483 // to lower-numbered components of the vector.
484 components.push_back(
485 spirv::LoadOp::create(rewriter, loc, componentAcOp));
486 }
487
488 // Create a vector of the components and then cast back to the larger
489 // bitwidth element type. For spirv.bitcast, the lower-numbered components
490 // of the vector map to lower-ordered bits of the larger bitwidth element
491 // type.
492
493 Type vectorType = srcElemType;
494 if (!isa<VectorType>(srcElemType))
495 vectorType = VectorType::get({ratio}, dstElemType);
496
497 // If both the source and destination are vector types, we need to make
498 // sure the scalar type is the same for composite construction later.
499 if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
500 if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
501 if (srcElemVecType.getElementType() !=
502 dstElemVecType.getElementType()) {
503 int64_t count =
504 dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
505
506 // Make sure not to create 1-element vectors, which are illegal in
507 // SPIR-V.
508 Type castType = srcElemVecType.getElementType();
509 if (count > 1)
510 castType = VectorType::get({count}, castType);
511
512 for (Value &c : components)
513 c = spirv::BitcastOp::create(rewriter, loc, castType, c);
514 }
515 }
516 Value vectorValue = spirv::CompositeConstructOp::create(
517 rewriter, loc, vectorType, components);
518
519 if (!isa<VectorType>(srcElemType))
520 vectorValue =
521 spirv::BitcastOp::create(rewriter, loc, srcElemType, vectorValue);
522 rewriter.replaceOp(loadOp, vectorValue);
523 return success();
524 }
525
526 return rewriter.notifyMatchFailure(
527 loadOp, "unsupported src/dst types for spirv.Load");
528 }
529};
530
531struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
533
534 LogicalResult
535 matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
536 ConversionPatternRewriter &rewriter) const override {
537 auto srcElemType =
538 cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
539 auto dstElemType =
540 cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
541 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
542 return rewriter.notifyMatchFailure(storeOp, "not scalar type");
543 if (!areSameBitwidthScalarType(srcElemType, dstElemType))
544 return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
545
546 Location loc = storeOp.getLoc();
547 Value value = adaptor.getValue();
548 if (srcElemType != dstElemType)
549 value = spirv::BitcastOp::create(rewriter, loc, dstElemType, value);
550 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
551 value, storeOp->getAttrs());
552 return success();
553 }
554};
555
556//===----------------------------------------------------------------------===//
557// Pass
558//===----------------------------------------------------------------------===//
559
560namespace {
561class UnifyAliasedResourcePass final
562 : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
563 UnifyAliasedResourcePass> {
564public:
565 explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
566 : getTargetEnvFn(std::move(getTargetEnv)) {}
567
568 void runOnOperation() override;
569
570private:
571 spirv::GetTargetEnvFn getTargetEnvFn;
572};
573
574void UnifyAliasedResourcePass::runOnOperation() {
575 spirv::ModuleOp moduleOp = getOperation();
576 MLIRContext *context = &getContext();
577
578 if (getTargetEnvFn) {
579 // This pass is only needed for targeting WebGPU, Metal, or layering
580 // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
581 // WGSL or MSL. The translation has limitations.
582 spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
583 spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
584 bool isVulkanOnAppleDevices =
585 clientAPI == spirv::ClientAPI::Vulkan &&
586 targetEnv.getVendorID() == spirv::Vendor::Apple;
587 if (clientAPI != spirv::ClientAPI::WebGPU &&
588 clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
589 return;
590 }
591
592 // Analyze aliased resources first.
593 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
594
595 ConversionTarget target(*context);
596 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
597 spirv::AccessChainOp, spirv::LoadOp,
598 spirv::StoreOp>(
599 [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
600 target.addLegalDialect<spirv::SPIRVDialect>();
601
602 // Run patterns to rewrite usages of non-canonical resources.
603 RewritePatternSet patterns(context);
604 patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
605 ConvertLoad, ConvertStore>(analysis, context);
606 if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
607 return signalPassFailure();
608
609 // Drop aliased attribute if we only have one single bound resource for a
610 // descriptor. We need to re-collect the map here given in the above the
611 // conversion is best effort; certain sets may not be converted.
612 AliasedResourceMap resourceMap =
613 collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
614 for (const auto &dr : resourceMap) {
615 const auto &resources = dr.second;
616 if (resources.size() == 1)
617 resources.front()->removeAttr("aliased");
618 }
619}
620} // namespace
621
622std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
624 return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
625}
return success()
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
static std::optional< int > deduceCanonicalResource(ArrayRef< spirv::SPIRVType > types)
Given a list of resource element types, returns the index of the canonical resource that all resource...
DenseMap< Descriptor, SmallVector< spirv::GlobalVariableOp > > AliasedResourceMap
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp)
Collects all aliased resources in the given SPIR-V moduleOp.
static Type getRuntimeArrayElementType(Type type)
Returns the element type if the given type is a runtime array resource: !spirv.ptr<!...
static bool areSameBitwidthScalarType(Type a, Type b)
std::pair< uint32_t, uint32_t > Descriptor
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
const ResourceAliasAnalysis & analysis
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Vendor getVendorID() const
Returns the vendor ID.
ClientAPI getClientAPI() const
Returns the client API.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
detail::InFlightRemark analysis(Location loc, RemarkOpts opts)
Report an optimization analysis remark.
Definition Remarks.h:567
std::unique_ptr< OperationPass< spirv::ModuleOp > > createUnifyAliasedResourcePass(GetTargetEnvFn getTargetEnv=nullptr)
std::function< spirv::TargetEnvAttr(spirv::ModuleOp)> GetTargetEnvFn
Creates an operation pass that unifies access of multiple aliased resources into access of one single...
Definition Passes.h:36
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override