MLIR  20.0.0git
UnifyAliasedResourcePass.cpp
Go to the documentation of this file.
1 //===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass that unifies access of multiple aliased resources
10 // into access of one single resource.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
20 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/SymbolTable.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
28 #include <algorithm>
29 #include <iterator>
30 
31 namespace mlir {
32 namespace spirv {
33 #define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
34 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
35 } // namespace spirv
36 } // namespace mlir
37 
38 #define DEBUG_TYPE "spirv-unify-aliased-resource"
39 
40 using namespace mlir;
41 
42 //===----------------------------------------------------------------------===//
43 // Utility functions
44 //===----------------------------------------------------------------------===//
45 
46 using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
49 
50 /// Collects all aliased resources in the given SPIR-V `moduleOp`.
51 static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
52  AliasedResourceMap aliasedResources;
53  moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
54  if (varOp->getAttrOfType<UnitAttr>("aliased")) {
55  std::optional<uint32_t> set = varOp.getDescriptorSet();
56  std::optional<uint32_t> binding = varOp.getBinding();
57  if (set && binding)
58  aliasedResources[{*set, *binding}].push_back(varOp);
59  }
60  });
61  return aliasedResources;
62 }
63 
64 /// Returns the element type if the given `type` is a runtime array resource:
65 /// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
66 /// otherwise.
68  auto ptrType = dyn_cast<spirv::PointerType>(type);
69  if (!ptrType)
70  return {};
71 
72  auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
73  if (!structType || structType.getNumElements() != 1)
74  return {};
75 
76  auto rtArrayType =
77  dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
78  if (!rtArrayType)
79  return {};
80 
81  return rtArrayType.getElementType();
82 }
83 
84 /// Given a list of resource element `types`, returns the index of the canonical
85 /// resource that all resources should be unified into. Returns std::nullopt if
86 /// unable to unify.
87 static std::optional<int>
89  // scalarNumBits: contains all resources' scalar types' bit counts.
90  // vectorNumBits: only contains resources whose element types are vectors.
91  // vectorIndices: each vector's original index in `types`.
92  SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
93  scalarNumBits.reserve(types.size());
94  vectorNumBits.reserve(types.size());
95  vectorIndices.reserve(types.size());
96 
97  for (const auto &indexedTypes : llvm::enumerate(types)) {
98  spirv::SPIRVType type = indexedTypes.value();
99  assert(type.isScalarOrVector());
100  if (auto vectorType = dyn_cast<VectorType>(type)) {
101  if (vectorType.getNumElements() % 2 != 0)
102  return std::nullopt; // Odd-sized vector has special layout
103  // requirements.
104 
105  std::optional<int64_t> numBytes = type.getSizeInBytes();
106  if (!numBytes)
107  return std::nullopt;
108 
109  scalarNumBits.push_back(
110  vectorType.getElementType().getIntOrFloatBitWidth());
111  vectorNumBits.push_back(*numBytes * 8);
112  vectorIndices.push_back(indexedTypes.index());
113  } else {
114  scalarNumBits.push_back(type.getIntOrFloatBitWidth());
115  }
116  }
117 
118  if (!vectorNumBits.empty()) {
119  // Choose the *vector* with the smallest bitwidth as the canonical resource,
120  // so that we can still keep vectorized load/store and avoid partial updates
121  // to large vectors.
122  auto *minVal = llvm::min_element(vectorNumBits);
123  // Make sure that the canonical resource's bitwidth is divisible by others.
124  // With out this, we cannot properly adjust the index later.
125  if (llvm::any_of(vectorNumBits,
126  [&](int bits) { return bits % *minVal != 0; }))
127  return std::nullopt;
128 
129  // Require all scalar type bit counts to be a multiple of the chosen
130  // vector's primitive type to avoid reading/writing subcomponents.
131  int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
132  int baseNumBits = scalarNumBits[index];
133  if (llvm::any_of(scalarNumBits,
134  [&](int bits) { return bits % baseNumBits != 0; }))
135  return std::nullopt;
136 
137  return index;
138  }
139 
140  // All element types are scalars. Then choose the smallest bitwidth as the
141  // cannonical resource to avoid subcomponent load/store.
142  auto *minVal = llvm::min_element(scalarNumBits);
143  if (llvm::any_of(scalarNumBits,
144  [minVal](int64_t bit) { return bit % *minVal != 0; }))
145  return std::nullopt;
146  return std::distance(scalarNumBits.begin(), minVal);
147 }
148 
150  return a.isIntOrFloat() && b.isIntOrFloat() &&
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // Analysis
156 //===----------------------------------------------------------------------===//
157 
158 namespace {
159 /// A class for analyzing aliased resources.
160 ///
161 /// Resources are expected to be spirv.GlobalVarible that has a descriptor set
162 /// and binding number. Such resources are of the type
163 /// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
164 ///
165 /// Right now, we only support the case that there is a single runtime array
166 /// inside the struct.
167 class ResourceAliasAnalysis {
168 public:
169  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)
170 
171  explicit ResourceAliasAnalysis(Operation *);
172 
173  /// Returns true if the given `op` can be rewritten to use a canonical
174  /// resource.
175  bool shouldUnify(Operation *op) const;
176 
177  /// Returns all descriptors and their corresponding aliased resources.
178  const AliasedResourceMap &getResourceMap() const { return resourceMap; }
179 
180  /// Returns the canonical resource for the given descriptor/variable.
181  spirv::GlobalVariableOp
182  getCanonicalResource(const Descriptor &descriptor) const;
183  spirv::GlobalVariableOp
184  getCanonicalResource(spirv::GlobalVariableOp varOp) const;
185 
186  /// Returns the element type for the given variable.
187  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
188 
189 private:
190  /// Given the descriptor and aliased resources bound to it, analyze whether we
191  /// can unify them and record if so.
192  void recordIfUnifiable(const Descriptor &descriptor,
194 
195  /// Mapping from a descriptor to all aliased resources bound to it.
196  AliasedResourceMap resourceMap;
197 
198  /// Mapping from a descriptor to the chosen canonical resource.
200 
201  /// Mapping from an aliased resource to its descriptor.
203 
204  /// Mapping from an aliased resource to its element (scalar/vector) type.
206 };
207 } // namespace
208 
209 ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
210  // Collect all aliased resources first and put them into different sets
211  // according to the descriptor.
212  AliasedResourceMap aliasedResources =
213  collectAliasedResources(cast<spirv::ModuleOp>(root));
214 
215  // For each resource set, analyze whether we can unify; if so, try to identify
216  // a canonical resource, whose element type has the largest bitwidth.
217  for (const auto &descriptorResource : aliasedResources) {
218  recordIfUnifiable(descriptorResource.first, descriptorResource.second);
219  }
220 }
221 
222 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
223  if (!op)
224  return false;
225 
226  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
227  auto canonicalOp = getCanonicalResource(varOp);
228  return canonicalOp && varOp != canonicalOp;
229  }
230  if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
231  auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
232  auto *varOp =
233  SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
234  return shouldUnify(varOp);
235  }
236 
237  if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
238  return shouldUnify(acOp.getBasePtr().getDefiningOp());
239  if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
240  return shouldUnify(loadOp.getPtr().getDefiningOp());
241  if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
242  return shouldUnify(storeOp.getPtr().getDefiningOp());
243 
244  return false;
245 }
246 
247 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
248  const Descriptor &descriptor) const {
249  auto varIt = canonicalResourceMap.find(descriptor);
250  if (varIt == canonicalResourceMap.end())
251  return {};
252  return varIt->second;
253 }
254 
255 spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
256  spirv::GlobalVariableOp varOp) const {
257  auto descriptorIt = descriptorMap.find(varOp);
258  if (descriptorIt == descriptorMap.end())
259  return {};
260  return getCanonicalResource(descriptorIt->second);
261 }
262 
264 ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
265  auto it = elementTypeMap.find(varOp);
266  if (it == elementTypeMap.end())
267  return {};
268  return it->second;
269 }
270 
271 void ResourceAliasAnalysis::recordIfUnifiable(
272  const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
273  // Collect the element types for all resources in the current set.
274  SmallVector<spirv::SPIRVType> elementTypes;
275  for (spirv::GlobalVariableOp resource : resources) {
276  Type elementType = getRuntimeArrayElementType(resource.getType());
277  if (!elementType)
278  return; // Unexpected resource variable type.
279 
280  auto type = cast<spirv::SPIRVType>(elementType);
281  if (!type.isScalarOrVector())
282  return; // Unexpected resource element type.
283 
284  elementTypes.push_back(type);
285  }
286 
287  std::optional<int> index = deduceCanonicalResource(elementTypes);
288  if (!index)
289  return;
290 
291  // Update internal data structures for later use.
292  resourceMap[descriptor].assign(resources.begin(), resources.end());
293  canonicalResourceMap[descriptor] = resources[*index];
294  for (const auto &resource : llvm::enumerate(resources)) {
295  descriptorMap[resource.value()] = descriptor;
296  elementTypeMap[resource.value()] = elementTypes[resource.index()];
297  }
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // Patterns
302 //===----------------------------------------------------------------------===//
303 
304 template <typename OpTy>
306 public:
307  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
308  MLIRContext *context, PatternBenefit benefit = 1)
309  : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
310 
311 protected:
312  const ResourceAliasAnalysis &analysis;
313 };
314 
315 struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
317 
318  LogicalResult
319  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
320  ConversionPatternRewriter &rewriter) const override {
321  // Just remove the aliased resource. Users will be rewritten to use the
322  // canonical one.
323  rewriter.eraseOp(varOp);
324  return success();
325  }
326 };
327 
328 struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
330 
331  LogicalResult
332  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
333  ConversionPatternRewriter &rewriter) const override {
334  // Rewrite the AddressOf op to get the address of the canoncical resource.
335  auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
336  auto srcVarOp = cast<spirv::GlobalVariableOp>(
337  SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
338  auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
339  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
340  return success();
341  }
342 };
343 
344 struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
346 
347  LogicalResult
348  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
349  ConversionPatternRewriter &rewriter) const override {
350  auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
351  if (!addressOp)
352  return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
353 
354  auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
355  auto srcVarOp = cast<spirv::GlobalVariableOp>(
356  SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
357  auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
358 
359  spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
360  spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
361 
362  if (srcElemType == dstElemType ||
363  areSameBitwidthScalarType(srcElemType, dstElemType)) {
364  // We have the same bitwidth for source and destination element types.
365  // Thie indices keep the same.
366  rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
367  acOp, adaptor.getBasePtr(), adaptor.getIndices());
368  return success();
369  }
370 
371  Location loc = acOp.getLoc();
372 
373  if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
374  // The source indices are for a buffer with scalar element types. Rewrite
375  // them into a buffer with vector element types. We need to scale the last
376  // index for the vector as a whole, then add one level of index for inside
377  // the vector.
378  int srcNumBytes = *srcElemType.getSizeInBytes();
379  int dstNumBytes = *dstElemType.getSizeInBytes();
380  assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
381 
382  auto indices = llvm::to_vector<4>(acOp.getIndices());
383  Value oldIndex = indices.back();
384  Type indexType = oldIndex.getType();
385 
386  int ratio = dstNumBytes / srcNumBytes;
387  auto ratioValue = rewriter.create<spirv::ConstantOp>(
388  loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
389 
390  indices.back() =
391  rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
392  indices.push_back(
393  rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
394 
395  rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
396  acOp, adaptor.getBasePtr(), indices);
397  return success();
398  }
399 
400  if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
401  (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
402  // The source indices are for a buffer with larger bitwidth scalar/vector
403  // element types. Rewrite them into a buffer with smaller bitwidth element
404  // types. We only need to scale the last index.
405  int srcNumBytes = *srcElemType.getSizeInBytes();
406  int dstNumBytes = *dstElemType.getSizeInBytes();
407  assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
408 
409  auto indices = llvm::to_vector<4>(acOp.getIndices());
410  Value oldIndex = indices.back();
411  Type indexType = oldIndex.getType();
412 
413  int ratio = srcNumBytes / dstNumBytes;
414  auto ratioValue = rewriter.create<spirv::ConstantOp>(
415  loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
416 
417  indices.back() =
418  rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
419 
420  rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
421  acOp, adaptor.getBasePtr(), indices);
422  return success();
423  }
424 
425  return rewriter.notifyMatchFailure(
426  acOp, "unsupported src/dst types for spirv.AccessChain");
427  }
428 };
429 
430 struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
432 
433  LogicalResult
434  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
435  ConversionPatternRewriter &rewriter) const override {
436  auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
437  auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
438  auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
439  auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
440 
441  Location loc = loadOp.getLoc();
442  auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
443  if (srcElemType == dstElemType) {
444  rewriter.replaceOp(loadOp, newLoadOp->getResults());
445  return success();
446  }
447 
448  if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
449  auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
450  newLoadOp.getValue());
451  rewriter.replaceOp(loadOp, castOp->getResults());
452 
453  return success();
454  }
455 
456  if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
457  (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
458  // The source and destination have scalar types of different bitwidths, or
459  // vector types of different component counts. For such cases, we load
460  // multiple smaller bitwidth values and construct a larger bitwidth one.
461 
462  int srcNumBytes = *srcElemType.getSizeInBytes();
463  int dstNumBytes = *dstElemType.getSizeInBytes();
464  assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
465  int ratio = srcNumBytes / dstNumBytes;
466  if (ratio > 4)
467  return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
468 
469  SmallVector<Value> components;
470  components.reserve(ratio);
471  components.push_back(newLoadOp);
472 
473  auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
474  if (!acOp)
475  return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");
476 
477  auto i32Type = rewriter.getI32Type();
478  Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
479  auto indices = llvm::to_vector<4>(acOp.getIndices());
480  for (int i = 1; i < ratio; ++i) {
481  // Load all subsequent components belonging to this element.
482  indices.back() = rewriter.create<spirv::IAddOp>(
483  loc, i32Type, indices.back(), oneValue);
484  auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
485  loc, acOp.getBasePtr(), indices);
486  // Assuming little endian, this reads lower-ordered bits of the number
487  // to lower-numbered components of the vector.
488  components.push_back(
489  rewriter.create<spirv::LoadOp>(loc, componentAcOp));
490  }
491 
492  // Create a vector of the components and then cast back to the larger
493  // bitwidth element type. For spirv.bitcast, the lower-numbered components
494  // of the vector map to lower-ordered bits of the larger bitwidth element
495  // type.
496 
497  Type vectorType = srcElemType;
498  if (!isa<VectorType>(srcElemType))
499  vectorType = VectorType::get({ratio}, dstElemType);
500 
501  // If both the source and destination are vector types, we need to make
502  // sure the scalar type is the same for composite construction later.
503  if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
504  if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
505  if (srcElemVecType.getElementType() !=
506  dstElemVecType.getElementType()) {
507  int64_t count =
508  dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
509 
510  // Make sure not to create 1-element vectors, which are illegal in
511  // SPIR-V.
512  Type castType = srcElemVecType.getElementType();
513  if (count > 1)
514  castType = VectorType::get({count}, castType);
515 
516  for (Value &c : components)
517  c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
518  }
519  }
520  Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
521  loc, vectorType, components);
522 
523  if (!isa<VectorType>(srcElemType))
524  vectorValue =
525  rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
526  rewriter.replaceOp(loadOp, vectorValue);
527  return success();
528  }
529 
530  return rewriter.notifyMatchFailure(
531  loadOp, "unsupported src/dst types for spirv.Load");
532  }
533 };
534 
535 struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
537 
538  LogicalResult
539  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
540  ConversionPatternRewriter &rewriter) const override {
541  auto srcElemType =
542  cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
543  auto dstElemType =
544  cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
545  if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
546  return rewriter.notifyMatchFailure(storeOp, "not scalar type");
547  if (!areSameBitwidthScalarType(srcElemType, dstElemType))
548  return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
549 
550  Location loc = storeOp.getLoc();
551  Value value = adaptor.getValue();
552  if (srcElemType != dstElemType)
553  value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
554  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
555  value, storeOp->getAttrs());
556  return success();
557  }
558 };
559 
560 //===----------------------------------------------------------------------===//
561 // Pass
562 //===----------------------------------------------------------------------===//
563 
564 namespace {
565 class UnifyAliasedResourcePass final
566  : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
567  UnifyAliasedResourcePass> {
568 public:
569  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
570  : getTargetEnvFn(std::move(getTargetEnv)) {}
571 
572  void runOnOperation() override;
573 
574 private:
575  spirv::GetTargetEnvFn getTargetEnvFn;
576 };
577 
578 void UnifyAliasedResourcePass::runOnOperation() {
579  spirv::ModuleOp moduleOp = getOperation();
580  MLIRContext *context = &getContext();
581 
582  if (getTargetEnvFn) {
583  // This pass is only needed for targeting WebGPU, Metal, or layering
584  // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
585  // WGSL or MSL. The translation has limitations.
586  spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
587  spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
588  bool isVulkanOnAppleDevices =
589  clientAPI == spirv::ClientAPI::Vulkan &&
590  targetEnv.getVendorID() == spirv::Vendor::Apple;
591  if (clientAPI != spirv::ClientAPI::WebGPU &&
592  clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
593  return;
594  }
595 
596  // Analyze aliased resources first.
597  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
598 
599  ConversionTarget target(*context);
600  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
601  spirv::AccessChainOp, spirv::LoadOp,
602  spirv::StoreOp>(
603  [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
604  target.addLegalDialect<spirv::SPIRVDialect>();
605 
606  // Run patterns to rewrite usages of non-canonical resources.
607  RewritePatternSet patterns(context);
609  ConvertLoad, ConvertStore>(analysis, context);
610  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
611  return signalPassFailure();
612 
613  // Drop aliased attribute if we only have one single bound resource for a
614  // descriptor. We need to re-collect the map here given in the above the
615  // conversion is best effort; certain sets may not be converted.
616  AliasedResourceMap resourceMap =
617  collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
618  for (const auto &dr : resourceMap) {
619  const auto &resources = dr.second;
620  if (resources.size() == 1)
621  resources.front()->removeAttr("aliased");
622  }
623 }
624 } // namespace
625 
626 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
628  return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
629 }
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp)
Collects all aliased resources in the given SPIR-V moduleOp.
static Type getRuntimeArrayElementType(Type type)
Returns the element type if the given type is a runtime array resource: !spirv.ptr<!...
static bool areSameBitwidthScalarType(Type a, Type b)
static std::optional< int > deduceCanonicalResource(ArrayRef< spirv::SPIRVType > types)
Given a list of resource element types, returns the index of the canonical resource that all resource...
std::pair< uint32_t, uint32_t > Descriptor
ConvertAliasResource(const ResourceAliasAnalysis &analysis, MLIRContext *context, PatternBenefit benefit=1)
const ResourceAliasAnalysis & analysis
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getI32Type()
Definition: Builders.cpp:107
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
std::optional< int64_t > getSizeInBytes()
Returns the size in bytes for each type.
Definition: SPIRVTypes.cpp:705
An attribute that specifies the target version, allowed extensions and capabilities,...
Vendor getVendorID() const
Returns the vendor ID.
ClientAPI getClientAPI() const
Returns the client API.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
RewritePatternSet & patterns
Definition: Patterns.h:74
std::unique_ptr< OperationPass< spirv::ModuleOp > > createUnifyAliasedResourcePass(GetTargetEnvFn getTargetEnv=nullptr)
std::function< spirv::TargetEnvAttr(spirv::ModuleOp)> GetTargetEnvFn
Creates an operation pass that unifies access of multiple aliased resources into access of one single...
Definition: Passes.h:36
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
LogicalResult matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override