MLIR  22.0.0git
PtrDialect.cpp
Go to the documentation of this file.
1 //===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Pointer dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Matchers.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::ptr;
23 
24 //===----------------------------------------------------------------------===//
25 // Pointer dialect
26 //===----------------------------------------------------------------------===//
27 
28 void PtrDialect::initialize() {
29  addOperations<
30 #define GET_OP_LIST
31 #include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
32  >();
33  addAttributes<
34 #define GET_ATTRDEF_LIST
35 #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
36  >();
37  addTypes<
38 #define GET_TYPEDEF_LIST
39 #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
40  >();
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // Common helper functions.
45 //===----------------------------------------------------------------------===//
46 
47 /// Verifies that the alignment attribute is a power of 2 if present.
48 static LogicalResult
49 verifyAlignment(std::optional<int64_t> alignment,
51  if (!alignment)
52  return success();
53  if (alignment.value() <= 0)
54  return emitError() << "alignment must be positive";
55  if (!llvm::isPowerOf2_64(alignment.value()))
56  return emitError() << "alignment must be a power of 2";
57  return success();
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // ConstantOp
62 //===----------------------------------------------------------------------===//
63 
64 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
65 
66 //===----------------------------------------------------------------------===//
67 // FromPtrOp
68 //===----------------------------------------------------------------------===//
69 
70 OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
71  // Fold the pattern:
72  // %ptr = ptr.to_ptr %v : type -> ptr
73  // (%mda = ptr.get_metadata %v : type)?
74  // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
75  // To:
76  // %val -> %v
77  Value ptrLike;
78  FromPtrOp fromPtr = *this;
79  while (fromPtr != nullptr) {
80  auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
81  // Cannot fold if it's not a `to_ptr` op or the initial and final types are
82  // different.
83  if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
84  return ptrLike;
85  Value md = fromPtr.getMetadata();
86  // If the type has trivial metadata fold.
87  if (!fromPtr.getType().hasPtrMetadata()) {
88  ptrLike = toPtr.getPtr();
89  } else if (md) {
90  // Fold if the metadata can be verified to be equal.
91  if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
92  mdOp && mdOp.getPtr() == toPtr.getPtr())
93  ptrLike = toPtr.getPtr();
94  }
95  // Check for a sequence of casts.
96  fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
97  }
98  return ptrLike;
99 }
100 
101 LogicalResult FromPtrOp::verify() {
102  if (isa<PtrType>(getType()))
103  return emitError() << "the result type cannot be `!ptr.ptr`";
104  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
105  return emitError()
106  << "expected the input and output to have the same memory space";
107  }
108  return success();
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // GatherOp
113 //===----------------------------------------------------------------------===//
114 
115 void GatherOp::getEffects(
117  &effects) {
118  // Gather performs reads from multiple memory locations specified by ptrs
119  effects.emplace_back(MemoryEffects::Read::get(), &getPtrsMutable());
120 }
121 
122 LogicalResult GatherOp::verify() {
123  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
124 
125  // Verify that the pointer type's memory space allows loads.
126  MemorySpaceAttrInterface ms =
127  cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
128  DataLayout dataLayout = DataLayout::closest(*this);
129  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
130  getAlignment(), &dataLayout, emitDiag))
131  return failure();
132 
133  // Verify the alignment.
134  return verifyAlignment(getAlignment(), emitDiag);
135 }
136 
137 void GatherOp::build(OpBuilder &builder, OperationState &state, Type resultType,
138  Value ptrs, Value mask, Value passthrough,
139  unsigned alignment) {
140  build(builder, state, resultType, ptrs, mask, passthrough,
141  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
142 }
143 
144 //===----------------------------------------------------------------------===//
145 // LoadOp
146 //===----------------------------------------------------------------------===//
147 
148 /// Verifies the attributes and the type of atomic memory access operations.
149 template <typename OpTy>
150 static LogicalResult
151 verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
152  if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
153  if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
154  return memOp.emitOpError("unsupported ordering '")
155  << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
156  if (!memOp.getAlignment())
157  return memOp.emitOpError("expected alignment for atomic access");
158  return success();
159  }
160  if (memOp.getSyncscope()) {
161  return memOp.emitOpError(
162  "expected syncscope to be null for non-atomic access");
163  }
164  return success();
165 }
166 
167 void LoadOp::getEffects(
169  &effects) {
170  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
171  // Volatile operations can have target-specific read-write effects on
172  // memory besides the one referred to by the pointer operand.
173  // Similarly, atomic operations that are monotonic or stricter cause
174  // synchronization that from a language point-of-view, are arbitrary
175  // read-writes into memory.
176  if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
177  getOrdering() != AtomicOrdering::unordered)) {
178  effects.emplace_back(MemoryEffects::Write::get());
179  effects.emplace_back(MemoryEffects::Read::get());
180  }
181 }
182 
183 LogicalResult LoadOp::verify() {
184  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
185  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
186  DataLayout dataLayout = DataLayout::closest(*this);
187  if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(),
188  &dataLayout, emitDiag))
189  return failure();
190  if (failed(verifyAlignment(getAlignment(), emitDiag)))
191  return failure();
192  return verifyAtomicMemOp(*this,
193  {AtomicOrdering::release, AtomicOrdering::acq_rel});
194 }
195 
196 void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
197  Value addr, unsigned alignment, bool isVolatile,
198  bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
199  AtomicOrdering ordering, StringRef syncscope) {
200  build(builder, state, type, addr,
201  alignment ? std::optional<int64_t>(alignment) : std::nullopt,
202  isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
203  syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
204 }
205 //===----------------------------------------------------------------------===//
206 // MaskedLoadOp
207 //===----------------------------------------------------------------------===//
208 
209 void MaskedLoadOp::getEffects(
211  &effects) {
212  // MaskedLoad performs reads from the memory location specified by ptr.
213  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
214 }
215 
216 LogicalResult MaskedLoadOp::verify() {
217  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
218  // Verify that the pointer type's memory space allows loads.
219  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
220  DataLayout dataLayout = DataLayout::closest(*this);
221  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
222  getAlignment(), &dataLayout, emitDiag))
223  return failure();
224 
225  // Verify the alignment.
226  return verifyAlignment(getAlignment(), emitDiag);
227 }
228 
229 void MaskedLoadOp::build(OpBuilder &builder, OperationState &state,
230  Type resultType, Value ptr, Value mask,
231  Value passthrough, unsigned alignment) {
232  build(builder, state, resultType, ptr, mask, passthrough,
233  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // MaskedStoreOp
238 //===----------------------------------------------------------------------===//
239 
240 void MaskedStoreOp::getEffects(
242  &effects) {
243  // MaskedStore performs writes to the memory location specified by ptr
244  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
245 }
246 
247 LogicalResult MaskedStoreOp::verify() {
248  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
249  // Verify that the pointer type's memory space allows stores.
250  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
251  DataLayout dataLayout = DataLayout::closest(*this);
252  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
253  getAlignment(), &dataLayout, emitDiag))
254  return failure();
255 
256  // Verify the alignment.
257  return verifyAlignment(getAlignment(), emitDiag);
258 }
259 
260 void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
261  Value value, Value ptr, Value mask,
262  unsigned alignment) {
263  build(builder, state, value, ptr, mask,
264  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // ScatterOp
269 //===----------------------------------------------------------------------===//
270 
271 void ScatterOp::getEffects(
273  &effects) {
274  // Scatter performs writes to multiple memory locations specified by ptrs
275  effects.emplace_back(MemoryEffects::Write::get(), &getPtrsMutable());
276 }
277 
278 LogicalResult ScatterOp::verify() {
279  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
280 
281  // Verify that the pointer type's memory space allows stores.
282  MemorySpaceAttrInterface ms =
283  cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
284  DataLayout dataLayout = DataLayout::closest(*this);
285  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
286  getAlignment(), &dataLayout, emitDiag))
287  return failure();
288 
289  // Verify the alignment.
290  return verifyAlignment(getAlignment(), emitDiag);
291 }
292 
293 void ScatterOp::build(OpBuilder &builder, OperationState &state, Value value,
294  Value ptrs, Value mask, unsigned alignment) {
295  build(builder, state, value, ptrs, mask,
296  alignment ? std::optional<int64_t>(alignment) : std::nullopt);
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // StoreOp
301 //===----------------------------------------------------------------------===//
302 
303 void StoreOp::getEffects(
305  &effects) {
306  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
307  // Volatile operations can have target-specific read-write effects on
308  // memory besides the one referred to by the pointer operand.
309  // Similarly, atomic operations that are monotonic or stricter cause
310  // synchronization that from a language point-of-view, are arbitrary
311  // read-writes into memory.
312  if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
313  getOrdering() != AtomicOrdering::unordered)) {
314  effects.emplace_back(MemoryEffects::Write::get());
315  effects.emplace_back(MemoryEffects::Read::get());
316  }
317 }
318 
319 LogicalResult StoreOp::verify() {
320  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
321  MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
322  DataLayout dataLayout = DataLayout::closest(*this);
323  if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(),
324  &dataLayout, emitDiag))
325  return failure();
326  if (failed(verifyAlignment(getAlignment(), emitDiag)))
327  return failure();
328  return verifyAtomicMemOp(*this,
329  {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
330 }
331 
332 void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
333  Value addr, unsigned alignment, bool isVolatile,
334  bool isNonTemporal, bool isInvariantGroup,
335  AtomicOrdering ordering, StringRef syncscope) {
336  build(builder, state, value, addr,
337  alignment ? std::optional<int64_t>(alignment) : std::nullopt,
338  isVolatile, isNonTemporal, isInvariantGroup, ordering,
339  syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // PtrAddOp
344 //===----------------------------------------------------------------------===//
345 
346 /// Fold: ptradd ptr + 0 -> ptr
347 OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
348  Attribute attr = adaptor.getOffset();
349  if (!attr)
350  return nullptr;
351  if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero())
352  return getBase();
353  return nullptr;
354 }
355 
356 LogicalResult PtrAddOp::inferReturnTypes(
357  MLIRContext *context, std::optional<Location> location, ValueRange operands,
358  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
359  SmallVectorImpl<Type> &inferredReturnTypes) {
360  // Get the base pointer and offset types.
361  Type baseType = operands[0].getType();
362  Type offsetType = operands[1].getType();
363 
364  auto offTy = dyn_cast<ShapedType>(offsetType);
365  if (!offTy) {
366  // If the offset isn't shaped, the result is always the base type.
367  inferredReturnTypes.push_back(baseType);
368  return success();
369  }
370  auto baseTy = dyn_cast<ShapedType>(baseType);
371  if (!baseTy) {
372  // Base isn't shaped, but offset is, use the ShapedType from offset with the
373  // base pointer as element type.
374  inferredReturnTypes.push_back(offTy.clone(baseType));
375  return success();
376  }
377 
378  // Both are shaped, their shape must match.
379  if (offTy.getShape() != baseTy.getShape()) {
380  if (location)
381  mlir::emitError(*location) << "shapes of base and offset must match";
382  return failure();
383  }
384 
385  // Make sure they are the same kind of shaped type.
386  if (baseType.getTypeID() != offsetType.getTypeID()) {
387  if (location)
388  mlir::emitError(*location) << "the shaped containers type must match";
389  return failure();
390  }
391  inferredReturnTypes.push_back(baseType);
392  return success();
393 }
394 
395 //===----------------------------------------------------------------------===//
396 // PtrDiffOp
397 //===----------------------------------------------------------------------===//
398 
399 LogicalResult PtrDiffOp::verify() {
400  // If the operands are not shaped early exit.
401  if (!isa<ShapedType>(getLhs().getType()))
402  return success();
403 
404  // Just check the container type matches, `SameOperandsAndResultShape` handles
405  // the actual shape.
406  if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) {
407  return emitError() << "expected the result to have the same container "
408  "type as the operands when operands are shaped";
409  }
410 
411  return success();
412 }
413 
414 ptr::PtrType PtrDiffOp::getPtrType() {
415  Type lhsType = getLhs().getType();
416  if (auto shapedType = dyn_cast<ShapedType>(lhsType))
417  return cast<ptr::PtrType>(shapedType.getElementType());
418  return cast<ptr::PtrType>(lhsType);
419 }
420 
421 Type PtrDiffOp::getIntType() {
422  Type resultType = getResult().getType();
423  if (auto shapedType = dyn_cast<ShapedType>(resultType))
424  return shapedType.getElementType();
425  return resultType;
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // ToPtrOp
430 //===----------------------------------------------------------------------===//
431 
432 OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
433  // Fold the pattern:
434  // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
435  // %ptr = ptr.to_ptr %val : type -> ptr
436  // To:
437  // %ptr -> %p
438  Value ptr;
439  ToPtrOp toPtr = *this;
440  while (toPtr != nullptr) {
441  auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
442  // Cannot fold if it's not a `from_ptr` op.
443  if (!fromPtr)
444  return ptr;
445  ptr = fromPtr.getPtr();
446  // Check for chains of casts.
447  toPtr = ptr.getDefiningOp<ToPtrOp>();
448  }
449  return ptr;
450 }
451 
452 LogicalResult ToPtrOp::verify() {
453  if (isa<PtrType>(getPtr().getType()))
454  return emitError() << "the input value cannot be of type `!ptr.ptr`";
455  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
456  return emitError()
457  << "expected the input and output to have the same memory space";
458  }
459  return success();
460 }
461 
462 //===----------------------------------------------------------------------===//
463 // TypeOffsetOp
464 //===----------------------------------------------------------------------===//
465 
466 llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
467  if (layout)
468  return layout->getTypeSize(getElementType());
469  DataLayout dl = DataLayout::closest(*this);
470  return dl.getTypeSize(getElementType());
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // Pointer API.
475 //===----------------------------------------------------------------------===//
476 
477 #include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
478 
479 #define GET_ATTRDEF_CLASSES
480 #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
481 
482 #include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
483 
484 #define GET_TYPEDEF_CLASSES
485 #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
486 
487 #define GET_OP_CLASSES
488 #include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyAtomicMemOp(OpTy memOp, ArrayRef< AtomicOrdering > unsupportedOrderings)
Verifies the attributes and the type of atomic memory access operations.
Definition: PtrDialect.cpp:151
static LogicalResult verifyAlignment(std::optional< int64_t > alignment, function_ref< InFlightDiagnostic()> emitError)
Verifies that the alignment attribute is a power of 2 if present.
Definition: PtrDialect.cpp:49
Attributes are known-constant values of operations.
Definition: Attributes.h:25
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:316
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition: Types.h:101
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.