MLIR  22.0.0git
PtrDialect.cpp
Go to the documentation of this file.
1 //===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Pointer dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Matchers.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::ptr;
22 
23 //===----------------------------------------------------------------------===//
24 // Pointer dialect
25 //===----------------------------------------------------------------------===//
26 
/// Dialect initialization hook: registers the operations, attributes and
/// types of the `ptr` dialect. Each `add*<...>()` template-argument list is
/// produced by a TableGen-generated `.cpp.inc` file; the `GET_*_LIST` macro
/// defined immediately before each include selects the class-name list that
/// the generated file expands to.
27 void PtrDialect::initialize() {
28  addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
31  >();
32  addAttributes<
33 #define GET_ATTRDEF_LIST
34 #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
35  >();
36  addTypes<
37 #define GET_TYPEDEF_LIST
38 #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
39  >();
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Common helper functions.
44 //===----------------------------------------------------------------------===//
45 
46 /// Verifies that the alignment attribute is a power of 2 if present.
47 static LogicalResult
48 verifyAlignment(std::optional<int64_t> alignment,
50  if (!alignment)
51  return success();
52  if (alignment.value() <= 0)
53  return emitError() << "alignment must be positive";
54  if (!llvm::isPowerOf2_64(alignment.value()))
55  return emitError() << "alignment must be a power of 2";
56  return success();
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // ConstantOp
61 //===----------------------------------------------------------------------===//
62 
63 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
64 
65 //===----------------------------------------------------------------------===//
66 // FromPtrOp
67 //===----------------------------------------------------------------------===//
68 
69 OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
70  // Fold the pattern:
71  // %ptr = ptr.to_ptr %v : type -> ptr
72  // (%mda = ptr.get_metadata %v : type)?
73  // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
74  // To:
75  // %val -> %v
76  Value ptrLike;
77  FromPtrOp fromPtr = *this;
78  while (fromPtr != nullptr) {
79  auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
80  // Cannot fold if it's not a `to_ptr` op or the initial and final types are
81  // different.
82  if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
83  return ptrLike;
84  Value md = fromPtr.getMetadata();
85  // If the type has trivial metadata fold.
86  if (!fromPtr.getType().hasPtrMetadata()) {
87  ptrLike = toPtr.getPtr();
88  } else if (md) {
89  // Fold if the metadata can be verified to be equal.
90  if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
91  mdOp && mdOp.getPtr() == toPtr.getPtr())
92  ptrLike = toPtr.getPtr();
93  }
94  // Check for a sequence of casts.
95  fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
96  }
97  return ptrLike;
98 }
99 
100 LogicalResult FromPtrOp::verify() {
101  if (isa<PtrType>(getType()))
102  return emitError() << "the result type cannot be `!ptr.ptr`";
103  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
104  return emitError()
105  << "expected the input and output to have the same memory space";
106  }
107  return success();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // GatherOp
112 //===----------------------------------------------------------------------===//
113 
114 void GatherOp::getEffects(
116  &effects) {
117  // Gather performs reads from multiple memory locations specified by ptrs
118  effects.emplace_back(MemoryEffects::Read::get(), &getPtrsMutable());
119 }
120 
121 LogicalResult GatherOp::verify() {
122  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
123 
124  // Verify that the pointer type's memory space allows loads.
125  MemorySpaceAttrInterface ms =
126  cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
127  DataLayout dataLayout = DataLayout::closest(*this);
128  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
129  getAlignment(), &dataLayout, emitDiag))
130  return failure();
131 
132  // Verify the alignment.
133  return verifyAlignment(getAlignment(), emitDiag);
134 }
135 
136 void GatherOp::build(OpBuilder &builder, OperationState &state, Type resultType,
137  Value ptrs, Value mask, Value passthrough,
138  unsigned alignment) {
139  build(builder, state, resultType, ptrs, mask, passthrough,
140  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // LoadOp
145 //===----------------------------------------------------------------------===//
146 
147 /// Verifies the attributes and the type of atomic memory access operations.
148 template <typename OpTy>
149 static LogicalResult
150 verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
151  if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
152  if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
153  return memOp.emitOpError("unsupported ordering '")
154  << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
155  if (!memOp.getAlignment())
156  return memOp.emitOpError("expected alignment for atomic access");
157  return success();
158  }
159  if (memOp.getSyncscope()) {
160  return memOp.emitOpError(
161  "expected syncscope to be null for non-atomic access");
162  }
163  return success();
164 }
165 
166 void LoadOp::getEffects(
168  &effects) {
169  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
170  // Volatile operations can have target-specific read-write effects on
171  // memory besides the one referred to by the pointer operand.
172  // Similarly, atomic operations that are monotonic or stricter cause
173  // synchronization that from a language point-of-view, are arbitrary
174  // read-writes into memory.
175  if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
176  getOrdering() != AtomicOrdering::unordered)) {
177  effects.emplace_back(MemoryEffects::Write::get());
178  effects.emplace_back(MemoryEffects::Read::get());
179  }
180 }
181 
182 LogicalResult LoadOp::verify() {
183  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
184  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
185  DataLayout dataLayout = DataLayout::closest(*this);
186  if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(),
187  &dataLayout, emitDiag))
188  return failure();
189  if (failed(verifyAlignment(getAlignment(), emitDiag)))
190  return failure();
191  return verifyAtomicMemOp(*this,
192  {AtomicOrdering::release, AtomicOrdering::acq_rel});
193 }
194 
195 void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
196  Value addr, unsigned alignment, bool isVolatile,
197  bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
198  AtomicOrdering ordering, StringRef syncscope) {
199  build(builder, state, type, addr,
200  alignment ? std::optional<int64_t>(alignment) : std::nullopt,
201  isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
202  syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
203 }
204 //===----------------------------------------------------------------------===//
205 // MaskedLoadOp
206 //===----------------------------------------------------------------------===//
207 
208 void MaskedLoadOp::getEffects(
210  &effects) {
211  // MaskedLoad performs reads from the memory location specified by ptr.
212  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
213 }
214 
215 LogicalResult MaskedLoadOp::verify() {
216  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
217  // Verify that the pointer type's memory space allows loads.
218  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
219  DataLayout dataLayout = DataLayout::closest(*this);
220  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
221  getAlignment(), &dataLayout, emitDiag))
222  return failure();
223 
224  // Verify the alignment.
225  return verifyAlignment(getAlignment(), emitDiag);
226 }
227 
228 void MaskedLoadOp::build(OpBuilder &builder, OperationState &state,
229  Type resultType, Value ptr, Value mask,
230  Value passthrough, unsigned alignment) {
231  build(builder, state, resultType, ptr, mask, passthrough,
232  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // MaskedStoreOp
237 //===----------------------------------------------------------------------===//
238 
239 void MaskedStoreOp::getEffects(
241  &effects) {
242  // MaskedStore performs writes to the memory location specified by ptr
243  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
244 }
245 
246 LogicalResult MaskedStoreOp::verify() {
247  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
248  // Verify that the pointer type's memory space allows stores.
249  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
250  DataLayout dataLayout = DataLayout::closest(*this);
251  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
252  getAlignment(), &dataLayout, emitDiag))
253  return failure();
254 
255  // Verify the alignment.
256  return verifyAlignment(getAlignment(), emitDiag);
257 }
258 
259 void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
260  Value value, Value ptr, Value mask,
261  unsigned alignment) {
262  build(builder, state, value, ptr, mask,
263  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // ScatterOp
268 //===----------------------------------------------------------------------===//
269 
270 void ScatterOp::getEffects(
272  &effects) {
273  // Scatter performs writes to multiple memory locations specified by ptrs
274  effects.emplace_back(MemoryEffects::Write::get(), &getPtrsMutable());
275 }
276 
277 LogicalResult ScatterOp::verify() {
278  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
279 
280  // Verify that the pointer type's memory space allows stores.
281  MemorySpaceAttrInterface ms =
282  cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
283  DataLayout dataLayout = DataLayout::closest(*this);
284  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
285  getAlignment(), &dataLayout, emitDiag))
286  return failure();
287 
288  // Verify the alignment.
289  return verifyAlignment(getAlignment(), emitDiag);
290 }
291 
292 void ScatterOp::build(OpBuilder &builder, OperationState &state, Value value,
293  Value ptrs, Value mask, unsigned alignment) {
294  build(builder, state, value, ptrs, mask,
295  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // StoreOp
300 //===----------------------------------------------------------------------===//
301 
302 void StoreOp::getEffects(
304  &effects) {
305  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
306  // Volatile operations can have target-specific read-write effects on
307  // memory besides the one referred to by the pointer operand.
308  // Similarly, atomic operations that are monotonic or stricter cause
309  // synchronization that from a language point-of-view, are arbitrary
310  // read-writes into memory.
311  if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
312  getOrdering() != AtomicOrdering::unordered)) {
313  effects.emplace_back(MemoryEffects::Write::get());
314  effects.emplace_back(MemoryEffects::Read::get());
315  }
316 }
317 
318 LogicalResult StoreOp::verify() {
319  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
320  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
321  DataLayout dataLayout = DataLayout::closest(*this);
322  if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(),
323  &dataLayout, emitDiag))
324  return failure();
325  if (failed(verifyAlignment(getAlignment(), emitDiag)))
326  return failure();
327  return verifyAtomicMemOp(*this,
328  {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
329 }
330 
331 void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
332  Value addr, unsigned alignment, bool isVolatile,
333  bool isNonTemporal, bool isInvariantGroup,
334  AtomicOrdering ordering, StringRef syncscope) {
335  build(builder, state, value, addr,
336  alignment ? std::optional<int64_t>(alignment) : std::nullopt,
337  isVolatile, isNonTemporal, isInvariantGroup, ordering,
338  syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // PtrAddOp
343 //===----------------------------------------------------------------------===//
344 
345 /// Fold: ptradd ptr + 0 -> ptr
346 OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
347  Attribute attr = adaptor.getOffset();
348  if (!attr)
349  return nullptr;
350  if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero())
351  return getBase();
352  return nullptr;
353 }
354 
355 LogicalResult PtrAddOp::inferReturnTypes(
356  MLIRContext *context, std::optional<Location> location, ValueRange operands,
357  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
358  SmallVectorImpl<Type> &inferredReturnTypes) {
359  // Get the base pointer and offset types.
360  Type baseType = operands[0].getType();
361  Type offsetType = operands[1].getType();
362 
363  auto offTy = dyn_cast<ShapedType>(offsetType);
364  if (!offTy) {
365  // If the offset isn't shaped, the result is always the base type.
366  inferredReturnTypes.push_back(baseType);
367  return success();
368  }
369  auto baseTy = dyn_cast<ShapedType>(baseType);
370  if (!baseTy) {
371  // Base isn't shaped, but offset is, use the ShapedType from offset with the
372  // base pointer as element type.
373  inferredReturnTypes.push_back(offTy.clone(baseType));
374  return success();
375  }
376 
377  // Both are shaped, their shape must match.
378  if (offTy.getShape() != baseTy.getShape()) {
379  if (location)
380  mlir::emitError(*location) << "shapes of base and offset must match";
381  return failure();
382  }
383 
384  // Make sure they are the same kind of shaped type.
385  if (baseType.getTypeID() != offsetType.getTypeID()) {
386  if (location)
387  mlir::emitError(*location) << "the shaped containers type must match";
388  return failure();
389  }
390  inferredReturnTypes.push_back(baseType);
391  return success();
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // ToPtrOp
396 //===----------------------------------------------------------------------===//
397 
398 OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
399  // Fold the pattern:
400  // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
401  // %ptr = ptr.to_ptr %val : type -> ptr
402  // To:
403  // %ptr -> %p
404  Value ptr;
405  ToPtrOp toPtr = *this;
406  while (toPtr != nullptr) {
407  auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
408  // Cannot fold if it's not a `from_ptr` op.
409  if (!fromPtr)
410  return ptr;
411  ptr = fromPtr.getPtr();
412  // Check for chains of casts.
413  toPtr = ptr.getDefiningOp<ToPtrOp>();
414  }
415  return ptr;
416 }
417 
418 LogicalResult ToPtrOp::verify() {
419  if (isa<PtrType>(getPtr().getType()))
420  return emitError() << "the input value cannot be of type `!ptr.ptr`";
421  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
422  return emitError()
423  << "expected the input and output to have the same memory space";
424  }
425  return success();
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // TypeOffsetOp
430 //===----------------------------------------------------------------------===//
431 
432 llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
433  if (layout)
434  return layout->getTypeSize(getElementType());
435  DataLayout dl = DataLayout::closest(*this);
436  return dl.getTypeSize(getElementType());
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // Pointer API.
441 //===----------------------------------------------------------------------===//
442 
443 #include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
444 
445 #define GET_ATTRDEF_CLASSES
446 #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
447 
448 #include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
449 
450 #define GET_TYPEDEF_CLASSES
451 #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
452 
453 #define GET_OP_CLASSES
454 #include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyAtomicMemOp(OpTy memOp, ArrayRef< AtomicOrdering > unsupportedOrderings)
Verifies the attributes and the type of atomic memory access operations.
Definition: PtrDialect.cpp:150
static LogicalResult verifyAlignment(std::optional< int64_t > alignment, function_ref< InFlightDiagnostic()> emitError)
Verifies that the alignment attribute is a power of 2 if present.
Definition: PtrDialect.cpp:48
Attributes are known-constant values of operations.
Definition: Attributes.h:25
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition: Types.h:101
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.