#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"
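// Helper that permutes `shape` in place: shape[i] takes the value of the old
// shape at position trans[i].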
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
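// Diagnostic helper that renders an array as "[a, b, c]"; used by the
// verifiers below to print expected vs. given shapes.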
static std::string makeString(T array, bool breakline = false) {
  llvm::raw_string_ostream os(buf);

  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";

  os << array.back() << "]";
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
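// Cache-hint validators (isReadHintOrNone / isWriteHintOrNone): the cache
// policy, if present, must be legal for a read or a write respectively.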
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
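// Builder for CreateNdDescOp with a statically-shaped memref source; the
// assert below enforces one offset per source dimension.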
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  build(builder, state, tdesc, source, dynamicOffsets,
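// Builder for CreateNdDescOp over a memref source with explicit shape,
// strides, and offsets; all three lists must be non-empty and equal in size.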
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,

  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
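// Variant of the builder that takes an integer source (a raw address), again
// with explicit shape, strides, and offsets.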
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,

  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
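// Verifier for the nd tensor descriptor (presumably CreateNdDescOp::verify):
// checks rank consistency, matching memory spaces, matching element type for
// memref sources, no scattered descriptors, and no SLM for 2D block descriptors.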
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());

    invalidRank |= (memrefTy.getRank() != rank);

        "Expecting the ranks of shape, strides, offsets, and the source (if "
        "the source is a memref) to match each other.");
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

        "Expecting the TensorDesc rank to be at most 2 and not greater than "
        "the ranks of shape, strides, offsets, or the memref source.");
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");
    return emitOpError("Expects a non-scattered TensorDesc.\n");
  if (getType().getRank() == 2 &&
      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto tdescTy = getTensorDescType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

    return emitOpError("Invalid result, it should be a VectorType.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto array_len = tdescTy.getArrayLength();

  if (getTranspose()) {
    auto trans = getTranspose().value();

    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();

      emitWarning("Invalid transpose attr. It is ignored.");
  if (tdescTy.getRank() == 2) {

    auto vnni_factor = valueShape.back();
    tdescShape[axis] /= vnni_factor;
    tdescShape.push_back(vnni_factor);

      emitWarning("Invalid Packed Attr. It is ignored (available for 2D "
                  "TensorDesc only).");
    auto it = tdescShape.begin();
    tdescShape.insert(it, array_len);
  if (tdescShape != valueShape)
    return emitOpError() << "Result shape doesn't match TensorDesc shape. "
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
  auto dstTy = getTensorDescType();
  auto valTy = getValueType();

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

    return emitOpError("Expecting a VectorType value.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto ty = getTensorDescType();

  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,

  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());

  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,

  build(builder, state, TensorDesc, source, ofrs);
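// Descriptor verifier (presumably CreateDescOp::verify): the source must be a
// 1D memref or a plain pointer, the descriptor must be scattered, memory
// spaces must match, chunk_size must be a supported value, per-lane accesses
// must be 32-bit aligned, and the total access size is capped at 512 bytes.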
  auto tdescTy = getTensorDescType();

         "Expecting the source to be a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;
  auto chunkSize = tdescTy.getChunkSize();

                                           16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
                       "8, 16, 32, 64, 128, or 256.");

  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {

        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");

  auto lscConstraints = 512 * 8; // at most 512 bytes per access, in bits
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) must not exceed 512 bytes.");
  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});

    shape.push_back(chunkSize);

  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected shape is " << makeString(shape) << "\n";
  auto tdescTy = getTensorDescType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto tdescElemTy = tdescTy.getElementType();

  if (tdescElemTy != valueElemTy)

        "Value should have the same element type as TensorDesc.");

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("load_gather has to be transposed.");
  if (valueShape != tdescShape)
    return emitOpError("Unexpected result shape")
           << " (Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("store_scatter has to be transposed.");
  if (valueShape != tdescShape)
    return emitOpError("Unexpected value shape")
           << " (Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,

  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source to be a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());

  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,

  build(builder, state, tensorDesc, ofrs);
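// DPAS verifier: lhs must be a 2D vector, rhs a 2D or 3D (VNNI-packed) vector,
// and the K extents of lhs and rhs must agree (for a packed rhs the K extent
// is rhsShape[0] * rhsShape[2]).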
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();

  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
                       "a 2D or 3D (packed) vector.");

  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>