// Headers this file needs beyond Debug.h (reconstructed; the listing starts
// mid-include-block):
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {
// Permute `shape` in place according to the permutation `trans`.
static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}
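
// For example, trans = {1, 0} permutes shape {8, 16} in place to {16, 8}.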
// Render an array as "[a, b, ...]" for diagnostics. Assumes `array` is
// non-empty, which holds for the shapes passed by the verifiers below.
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}
static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = llvm::to_vector(ty.getShape());
  else
    shape.push_back(1); // scalars are treated as shape [1]
  return shape;
}

static int64_t getRankOf(Value val) {
  if (auto ty = llvm::dyn_cast<ShapedType>(val.getType()))
    return ty.getRank();
  return 0; // scalars are rank 0
}
// An absent hint is always valid; otherwise the policy must be legal on the
// read path.
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}
// An absent hint is always valid; otherwise the policy must be legal on the
// write path.
static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
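
// Together these helpers partition the cache policies: CACHED and UNCACHED
// are legal on both paths, STREAMING and READ_INVALIDATE are read-only, and
// WRITE_BACK and WRITE_THROUGH are write-only. So a hint such as
// #xegpu.cache_hint<streaming> verifies on loads but is rejected on stores.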
// Shared checks for gather (load) and scatter (store) ops.
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy, UnitAttr transposeAttr,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto chunkSize = tdescTy.getChunkSize();

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  if (tdescShape[0] != maskShape[0])
    return emitError()
           << "dim-0 of the Mask and TensorDesc should be the same.";

  // A valid shape for the SIMT case: a 1D value holding one chunk per lane.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (transposeAttr)
      return emitError()
             << "TensorDesc doesn't need TransposeAttr for SIMT code";
    return success();
  }

  // The SIMD case: a rank-2 descriptor must be transposed before comparing
  // shapes.
  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
    if (!transposeAttr)
      return emitError() << "rank-2 tensor has to be transposed.";
    transpose({1, 0}, tdescShape);
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  // A static memref already carries shape and strides, so those operand
  // groups stay empty.
  build(builder, state, tdesc, source, dynamicOffsets,
        /*dynamic shape=*/ValueRange({}), /*dynamic strides=*/ValueRange({}),
        /*static offsets=*/staticOffsets, /*static shape=*/{},
        /*static strides=*/{});
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  // Note the parentheses: without them, `||` would swallow the assert
  // message string.
  Type srcTy = source.getType();
  assert((isa<IntegerType>(srcTy) || isa<MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<int64_t> staticOffsets, staticShape, staticStrides;
  llvm::SmallVector<Value> dynamicOffsets, dynamicShape, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  // If shape and strides are redundant with the memref type, drop the
  // attributes to keep the printed IR clean.
  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
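
// Illustrative IR (syntax approximate) produced by these builders:
//   %t = xegpu.create_nd_tdesc %src[0, 0]
//       : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>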
LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // The memory space of the created TensorDesc must match the source.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // If the source is a memref, its rank and element type must agree with
  // the descriptor.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType())) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  if (invalidRank)
    return emitOpError(
        "Expecting the ranks of shape, strides, offsets, and the source (if "
        "the source is a memref) to match.");

  // The result TensorDesc is at most rank 2 and no larger than the source.
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank to be at most 2 and not greater than "
        "the ranks of shape, strides, offsets, or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // A 1D result with fewer elements than the descriptor is an SIMT op: the
  // load is distributed across the lanes of a subgroup, and the descriptor
  // needs no layout attribute.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // Subgroup size is architecture-dependent, so only even distribution is
    // checked here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Otherwise this is SIMD mode: the adjusted descriptor shape must equal
  // the result shape.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // The permutation must stay within the descriptor's rank.
    bool valid = llvm::all_of(
        trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      // VNNI packing folds vnni_factor rows into a trailing dimension.
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  return success();
}
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType value.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // As in LoadNdOp, a 1D value with fewer elements than the descriptor is an
  // SIMT op; only even distribution is checked.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems) {
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
    }
    return success();
  }

  // SIMD code must have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape) {
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;
  }

  return success();
}
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets must match the descriptor rank.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  // Materialize the per-lane offsets as a single index vector.
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source to be a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the created TensorDesc must match the source; an
  // integer source is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Each lane accesses chunk_size elements; for chunked accesses the
  // per-lane size must be 32-bit aligned so narrow data can be bitcast to
  // 32-bit for performance.
  auto chunkSize = tdescTy.getChunkSize();
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  auto lscConstraints = 512 * 8; // each LSC access is at most 512 bytes
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) must not exceed 512 bytes.");

  // The descriptor shape must be [num_offsets] or [num_offsets, chunk_size].
  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected " << makeString(shape) << "\n";

  return success();
}
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  // Shape and type checks are shared with StoreScatterOp.
  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}
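
// Illustrative IR (syntax approximate): bumping every lane's offset:
//   %t2 = xegpu.update_offset %t, %delta
//       : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>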
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT case: all operands are 1D vectors. The B operand must carry a
  // whole number of 32-bit units; further semantic checks are skipped since
  // they depend on the architecture. Assumes elemTy is at most 32 bits.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError(
          "Expecting the B operand size to be a multiple of 32 bits.");
    return success();
  }

  // SIMD case.
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  // For packed (3D) B, the K dimension is split as rhsShape[0] * rhsShape[2].
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");
  return success();
}
LogicalResult ConvertLayoutOp::verify() {
  auto srcMap = getSrcMapAttr();
  auto resMap = getResMapAttr();
  if (!srcMap)
    return emitOpError("expected srcMap.");
  if (!resMap)
    return emitOpError("expected resMap.");

  if (srcMap == resMap)
    return emitOpError("expected different srcMap and resMap.");

  // Both layouts must be at the same level: either both workgroup
  // (WgLayout) or both subgroup (SgLayout).
  if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
      (!srcMap.isSgLayout() || !resMap.isSgLayout()))
    return emitOpError(
        "expected srcMap and resMap to be both WgLayout or both SgLayout.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
    return emitOpError("invalid srcMap, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
    return emitOpError("invalid resMap, data cannot be evenly distributed.");

  return mlir::success();
}
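
// Illustrative IR (syntax approximate): converting between two workgroup
// layouts of a 32x64 tile; both maps distribute the shape evenly
// (4*8 = 32, 4*16 = 64 and 2*16 = 32, 8*8 = 64):
//   %r = xegpu.convert_layout %src
//       <{srcMap = #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 16]>,
//         resMap = #xegpu.layout<sg_layout = [2, 8], sg_data = [16, 8]>}>
//       : vector<32x64xf16>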
} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>