MLIR 21.0.0git
UnifyAliasedResourcePass.cpp
//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies access of multiple aliased resources
// into access of one single resource.
//
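// For example (an illustrative sketch, not IR taken from this file): two
// spirv.GlobalVariable ops marked `aliased` and bound to the same
// (set #, binding #), one viewing the underlying buffer as
// `!spirv.rtarray<f32>` and the other as `!spirv.rtarray<vector<4xf32>>`, are
// rewritten so that every access goes through one canonical variable, with
// access indices adjusted to match its element type.
//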
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/Passes.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include <iterator>

namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;

/// Collects all aliased resources in the given SPIR-V `moduleOp`.
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
  AliasedResourceMap aliasedResources;
  moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
    if (varOp->getAttrOfType<UnitAttr>("aliased")) {
      std::optional<uint32_t> set = varOp.getDescriptorSet();
      std::optional<uint32_t> binding = varOp.getBinding();
      if (set && binding)
        aliasedResources[{*set, *binding}].push_back(varOp);
    }
  });
  return aliasedResources;
}

/// Returns the element type if the given `type` is a runtime array resource:
/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
/// otherwise.
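/// For example (illustrative), for
/// `!spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16>)>, StorageBuffer>`
/// this returns `vector<4xf32>`.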
static Type getRuntimeArrayElementType(Type type) {
  auto ptrType = dyn_cast<spirv::PointerType>(type);
  if (!ptrType)
    return {};

  auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
  if (!structType || structType.getNumElements() != 1)
    return {};

  auto rtArrayType =
      dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
  if (!rtArrayType)
    return {};

  return rtArrayType.getElementType();
}

/// Given a list of resource element `types`, returns the index of the canonical
/// resource that all resources should be unified into. Returns std::nullopt if
/// unable to unify.
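/// For example (illustrative): for {f32, vector<4xf32>} the vector resource is
/// chosen; for {i32, i64} the i32 resource is chosen; for {f16, vector<3xf32>}
/// std::nullopt is returned because the odd-sized vector is rejected.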
static std::optional<int>
deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
  // scalarNumBits: contains all resources' scalar types' bit counts.
  // vectorNumBits: only contains resources whose element types are vectors.
  // vectorIndices: each vector's original index in `types`.
  SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
  scalarNumBits.reserve(types.size());
  vectorNumBits.reserve(types.size());
  vectorIndices.reserve(types.size());

  for (const auto &indexedTypes : llvm::enumerate(types)) {
    spirv::SPIRVType type = indexedTypes.value();
    assert(type.isScalarOrVector());
    if (auto vectorType = dyn_cast<VectorType>(type)) {
      if (vectorType.getNumElements() % 2 != 0)
        return std::nullopt; // Odd-sized vector has special layout
                             // requirements.

      std::optional<int64_t> numBytes = type.getSizeInBytes();
      if (!numBytes)
        return std::nullopt;

      scalarNumBits.push_back(
          vectorType.getElementType().getIntOrFloatBitWidth());
      vectorNumBits.push_back(*numBytes * 8);
      vectorIndices.push_back(indexedTypes.index());
    } else {
      scalarNumBits.push_back(type.getIntOrFloatBitWidth());
    }
  }

  if (!vectorNumBits.empty()) {
    // Choose the *vector* with the smallest bitwidth as the canonical resource,
    // so that we can still keep vectorized load/store and avoid partial updates
    // to large vectors.
    auto *minVal = llvm::min_element(vectorNumBits);
    // Make sure the canonical vector's bitwidth divides all the other vectors'
    // bitwidths. Without this, we cannot properly adjust the index later.
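    // For example (illustrative): with vector bit counts {64, 128}, the 64-bit
    // vector is chosen and 128 % 64 == 0, so an index into the 128-bit
    // resource can be scaled by a factor of 2.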
    if (llvm::any_of(vectorNumBits,
                     [&](int bits) { return bits % *minVal != 0; }))
      return std::nullopt;

    // Require all scalar type bit counts to be a multiple of the chosen
    // vector's primitive type to avoid reading/writing subcomponents.
    int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
    int baseNumBits = scalarNumBits[index];
    if (llvm::any_of(scalarNumBits,
                     [&](int bits) { return bits % baseNumBits != 0; }))
      return std::nullopt;

    return index;
  }

  // All element types are scalars. Then choose the smallest bitwidth as the
  // canonical resource to avoid subcomponent load/store.
  auto *minVal = llvm::min_element(scalarNumBits);
  if (llvm::any_of(scalarNumBits,
                   [minVal](int64_t bit) { return bit % *minVal != 0; }))
    return std::nullopt;
  return std::distance(scalarNumBits.begin(), minVal);
}

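/// Returns true if both `a` and `b` are scalar integer/float types with the
/// same bitwidth.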
static bool areSameBitwidthScalarType(Type a, Type b) {
  return a.isIntOrFloat() && b.isIntOrFloat() &&
         a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
}

//===----------------------------------------------------------------------===//
// Analysis
//===----------------------------------------------------------------------===//

namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spirv.GlobalVariable ops that have a descriptor
/// set and binding number. Such resources are of the type
/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
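///
/// For example (illustrative), resources of type
/// `!spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4>)>, StorageBuffer>`
/// are analyzed; structs with additional members are not unified and are left
/// untouched.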
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether we
  /// can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace

ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
  // Collect all aliased resources first and put them into different sets
  // according to the descriptor.
  AliasedResourceMap aliasedResources =
      collectAliasedResources(cast<spirv::ModuleOp>(root));

  // For each resource set, analyze whether we can unify; if so, record the
  // canonical resource that the others should be unified into.
  for (const auto &descriptorResource : aliasedResources) {
    recordIfUnifiable(descriptorResource.first, descriptorResource.second);
  }
}

bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
  if (!op)
    return false;

  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
    auto canonicalOp = getCanonicalResource(varOp);
    return canonicalOp && varOp != canonicalOp;
  }
  if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
    auto *varOp =
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
    return shouldUnify(varOp);
  }

  if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
    return shouldUnify(acOp.getBasePtr().getDefiningOp());
  if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
    return shouldUnify(loadOp.getPtr().getDefiningOp());
  if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
    return shouldUnify(storeOp.getPtr().getDefiningOp());

  return false;
}

spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    const Descriptor &descriptor) const {
  auto varIt = canonicalResourceMap.find(descriptor);
  if (varIt == canonicalResourceMap.end())
    return {};
  return varIt->second;
}

spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    spirv::GlobalVariableOp varOp) const {
  auto descriptorIt = descriptorMap.find(varOp);
  if (descriptorIt == descriptorMap.end())
    return {};
  return getCanonicalResource(descriptorIt->second);
}

spirv::SPIRVType
ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
  auto it = elementTypeMap.find(varOp);
  if (it == elementTypeMap.end())
    return {};
  return it->second;
}

void ResourceAliasAnalysis::recordIfUnifiable(
    const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
  // Collect the element types for all resources in the current set.
  SmallVector<spirv::SPIRVType> elementTypes;
  for (spirv::GlobalVariableOp resource : resources) {
    Type elementType = getRuntimeArrayElementType(resource.getType());
    if (!elementType)
      return; // Unexpected resource variable type.

    auto type = cast<spirv::SPIRVType>(elementType);
    if (!type.isScalarOrVector())
      return; // Unexpected resource element type.

    elementTypes.push_back(type);
  }

  std::optional<int> index = deduceCanonicalResource(elementTypes);
  if (!index)
    return;

  // Update internal data structures for later use.
  resourceMap[descriptor].assign(resources.begin(), resources.end());
  canonicalResourceMap[descriptor] = resources[*index];
  for (const auto &resource : llvm::enumerate(resources)) {
    descriptorMap[resource.value()] = descriptor;
    elementTypeMap[resource.value()] = elementTypes[resource.index()];
  }
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  const ResourceAliasAnalysis &analysis;
};

struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Just remove the aliased resource. Users will be rewritten to use the
    // canonical one.
    rewriter.eraseOp(varOp);
    return success();
  }
};

struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Rewrite the AddressOf op to get the address of the canonical resource.
    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
    auto srcVarOp = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
    return success();
  }
};

struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
    if (!addressOp)
      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");

    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
    auto srcVarOp = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);

    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);

    if (srcElemType == dstElemType ||
        areSameBitwidthScalarType(srcElemType, dstElemType)) {
      // The source and destination element types have the same bitwidth, so
      // the indices stay the same.
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), adaptor.getIndices());
      return success();
    }

    Location loc = acOp.getLoc();

    if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
      // The source indices are for a buffer with scalar element types. Rewrite
      // them for a buffer with vector element types: scale the last index to
      // address the vector as a whole, then add one level of index for the
      // position inside the vector.
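      // For example (illustrative): with an f32 source and a vector<4xf32>
      // destination (ratio = 4), the last index i becomes (i / 4, i % 4), so
      // i = 6 addresses component 2 of vector element 1.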
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);

      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value oldIndex = indices.back();
      Type indexType = oldIndex.getType();

      int ratio = dstNumBytes / srcNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

      indices.back() =
          rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
      indices.push_back(
          rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source indices are for a buffer with larger bitwidth scalar/vector
      // element types. Rewrite them for a buffer with smaller bitwidth element
      // types. We only need to scale the last index.
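      // For example (illustrative): with an i64 source and an i32 destination
      // (ratio = 2), the last index i becomes i * 2; the load/store patterns
      // below then access the two consecutive i32 words.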
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);

      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value oldIndex = indices.back();
      Type indexType = oldIndex.getType();

      int ratio = srcNumBytes / dstNumBytes;
      auto ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

      indices.back() =
          rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    return rewriter.notifyMatchFailure(
        acOp, "unsupported src/dst types for spirv.AccessChain");
  }
};

struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
    auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
    auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
    auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
    if (srcElemType == dstElemType) {
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.getValue());
      rewriter.replaceOp(loadOp, castOp->getResults());

      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source and destination have scalar types of different bitwidths, or
      // vector types of different component counts. For such cases, we load
      // multiple smaller bitwidth values and construct a larger bitwidth one.
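      // For example (illustrative): loading an i64 from an i32 canonical
      // resource loads two consecutive i32 words, assembles them into a
      // vector<2xi32> with CompositeConstruct, and bitcasts that to i64.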

      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
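      // A SPIR-V vector has at most 4 components (without extra capabilities),
      // so we cannot reassemble more than 4 loaded words into one value here.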
      if (ratio > 4)
        return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

      SmallVector<Value> components;
      components.reserve(ratio);
      components.push_back(newLoadOp);

      auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
      if (!acOp)
        return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");

      auto i32Type = rewriter.getI32Type();
      Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
      auto indices = llvm::to_vector<4>(acOp.getIndices());
      for (int i = 1; i < ratio; ++i) {
        // Load all subsequent components belonging to this element.
        indices.back() = rewriter.create<spirv::IAddOp>(
            loc, i32Type, indices.back(), oneValue);
        auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
            loc, acOp.getBasePtr(), indices);
        // Assuming little endian, this reads lower-ordered bits of the number
        // to lower-numbered components of the vector.
        components.push_back(
            rewriter.create<spirv::LoadOp>(loc, componentAcOp));
      }

      // Create a vector of the components and then cast back to the larger
      // bitwidth element type. For spirv.bitcast, the lower-numbered components
      // of the vector map to lower-ordered bits of the larger bitwidth element
      // type.

      Type vectorType = srcElemType;
      if (!isa<VectorType>(srcElemType))
        vectorType = VectorType::get({ratio}, dstElemType);

      // If both the source and destination are vector types, we need to make
      // sure the scalar type is the same for composite construction later.
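      // For example (illustrative): with a vector<4xf32> source and a
      // vector<4xf16> destination, each loaded vector<4xf16> (8 bytes) is
      // bitcast to vector<2xf32> so that two of them can be composite
      // constructed into the final vector<4xf32>.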
      if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
        if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
          if (srcElemVecType.getElementType() !=
              dstElemVecType.getElementType()) {
            int64_t count =
                dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);

            // Make sure not to create 1-element vectors, which are illegal in
            // SPIR-V.
            Type castType = srcElemVecType.getElementType();
            if (count > 1)
              castType = VectorType::get({count}, castType);

            for (Value &c : components)
              c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
          }
        }
      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
          loc, vectorType, components);

      if (!isa<VectorType>(srcElemType))
        vectorValue =
            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
      rewriter.replaceOp(loadOp, vectorValue);
      return success();
    }

    return rewriter.notifyMatchFailure(
        loadOp, "unsupported src/dst types for spirv.Load");
  }
};

struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcElemType =
        cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
    auto dstElemType =
        cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
    if (!areSameBitwidthScalarType(srcElemType, dstElemType))
      return rewriter.notifyMatchFailure(storeOp, "different bitwidth");

    Location loc = storeOp.getLoc();
    Value value = adaptor.getValue();
    if (srcElemType != dstElemType)
      value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
                                                value, storeOp->getAttrs());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//

namespace {
class UnifyAliasedResourcePass final
    : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
          UnifyAliasedResourcePass> {
public:
  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
      : getTargetEnvFn(std::move(getTargetEnv)) {}

  void runOnOperation() override;

private:
  spirv::GetTargetEnvFn getTargetEnvFn;
};

void UnifyAliasedResourcePass::runOnOperation() {
  spirv::ModuleOp moduleOp = getOperation();
  MLIRContext *context = &getContext();

  if (getTargetEnvFn) {
    // This pass is only needed when targeting WebGPU, Metal, or layering
    // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
    // WGSL or MSL. The translation has limitations.
    spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
    spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
    bool isVulkanOnAppleDevices =
        clientAPI == spirv::ClientAPI::Vulkan &&
        targetEnv.getVendorID() == spirv::Vendor::Apple;
    if (clientAPI != spirv::ClientAPI::WebGPU &&
        clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
      return;
  }

  // Analyze aliased resources first.
  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();

  ConversionTarget target(*context);
  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
                               spirv::AccessChainOp, spirv::LoadOp,
                               spirv::StoreOp>(
      [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
  target.addLegalDialect<spirv::SPIRVDialect>();

  // Run patterns to rewrite usages of non-canonical resources.
  RewritePatternSet patterns(context);
  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
               ConvertLoad, ConvertStore>(analysis, context);
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
    return signalPassFailure();

  // Drop the aliased attribute if only one resource is bound to a descriptor.
  // We need to re-collect the map here because the above conversion is best
  // effort; certain sets may not have been converted.
  AliasedResourceMap resourceMap =
      collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
  for (const auto &dr : resourceMap) {
    const auto &resources = dr.second;
    if (resources.size() == 1)
      resources.front()->removeAttr("aliased");
  }
}
} // namespace

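// Example usage (an illustrative sketch, not from this file): schedule the
// pass on the nested spirv.module, e.g.
//   passManager.nest<spirv::ModuleOp>().addPass(
//       spirv::createUnifyAliasedResourcePass());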
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
  return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
}